From 8120be66a60023b0481e0cfa7347ad976090f2ff Mon Sep 17 00:00:00 2001
From: Colin
Date: Mon, 26 Feb 2024 23:59:00 +0800
Subject: [PATCH] Separate train and val datasets.

---
 wit/lit_train.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/wit/lit_train.py b/wit/lit_train.py
index 6596fa0..63f7296 100644
--- a/wit/lit_train.py
+++ b/wit/lit_train.py
@@ -26,7 +26,6 @@ dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
 dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
 train_batch_size = 256
 val_batch_size = 16
-limit_val_batches = 128
 num_proc = 8
 max_epochs = 1000
 strategy = "fsdp"
@@ -35,10 +34,10 @@ seed = 42
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, size=65536):
+    def __init__(self, start, end, size=65536):
         self.size = size
         self.features = []
-        a = torch.randint(0, 1024, [size])
+        a = torch.randint(start, end, [size])
         self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)
 
     def __len__(self):
@@ -144,8 +143,8 @@ if __name__ == "__main__":
     train_dataset = ConcatDataset(train_dataset_list)
     val_dataset = ConcatDataset(val_dataset_list)
 
-    train_dataset = SpecialDataset()
-    val_dataset = SpecialDataset()
+    train_dataset = SpecialDataset(0, 1000, 65536)
+    val_dataset = SpecialDataset(1000, 1024, 1024)
 
     train_dataloader = DataLoader(
         train_dataset,
@@ -166,13 +165,7 @@ if __name__ == "__main__":
 
     torch.set_float32_matmul_precision("medium")
     precision = precision
-    lit_trainer = pl.Trainer(
-        accelerator="gpu",
-        precision=precision,
-        strategy=strategy,
-        max_epochs=max_epochs,
-        limit_val_batches=limit_val_batches,
-    )
+    lit_trainer = pl.Trainer(accelerator="gpu", precision=precision, strategy=strategy, max_epochs=max_epochs)
     lit_trainer.fit(
         lit_module,
         train_dataloaders=train_dataloader,