diff --git a/wit/lit_module.py b/wit/lit_module.py
index 3f8d2f1..be86bf1 100644
--- a/wit/lit_module.py
+++ b/wit/lit_module.py
@@ -97,4 +97,5 @@ class LitModule(pl.LightningModule):
             mode="max",
             stopping_threshold=1,
         )
-        return [checkpoint_callback, early_stop_callback]
+        return [checkpoint_callback]
+        # return [checkpoint_callback, early_stop_callback]
diff --git a/wit/lit_train.py b/wit/lit_train.py
index 3310674..6596fa0 100644
--- a/wit/lit_train.py
+++ b/wit/lit_train.py
@@ -20,33 +20,36 @@ from tokenization_qwen import QWenTokenizer
 
 model_name = "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
-precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
+precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 tokenizer_name_or_path = None
 dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"]
 dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"]
-train_batch_size = 1
-val_batch_size = 1
-accumulate_grad_batches = 32
+train_batch_size = 256
+val_batch_size = 16
+limit_val_batches = 128
 num_proc = 8
-max_epochs = None
+max_epochs = 1000
 strategy = "fsdp"
 resume_from_ckpt_path = None
 seed = 42
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, size=4096):
+    def __init__(self, size=65536):
         self.size = size
         self.features = []
+        a = torch.randint(0, 1024, [size])
+        self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)
 
     def __len__(self):
         return self.size
 
     def __getitem__(self, idx):
         output = {}
-        output["input_ids"] = torch.randint(0, 4096, [128])
-        output["labels"] = output["input_ids"]
-        output["token_type_ids"] = torch.zeros([128])
+        data = self.data[idx]
+        output["input_ids"] = data
+        output["labels"] = data
+        output["token_type_ids"] = torch.zeros(data.shape)
         return output
 
@@ -144,7 +147,6 @@ if __name__ == "__main__":
     train_dataset = SpecialDataset()
     val_dataset = SpecialDataset()
 
-    # dataloaders
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=train_batch_size,
@@ -159,21 +161,17 @@ if __name__ == "__main__":
         num_workers=num_proc,
         collate_fn=DefaultDataCollator(),
         persistent_workers=True,
+        shuffle=True,
     )
 
-    ne = next(train_dataloader._get_iterator())
-
-    # trainer
-    # apply_all_patches()
     torch.set_float32_matmul_precision("medium")
     precision = precision
     lit_trainer = pl.Trainer(
         accelerator="gpu",
         precision=precision,
-        log_every_n_steps=5,
-        accumulate_grad_batches=accumulate_grad_batches,
         strategy=strategy,
         max_epochs=max_epochs,
+        limit_val_batches=limit_val_batches,
     )
     lit_trainer.fit(
         lit_module,
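
For reference, a minimal standalone sketch (not part of the diff, assuming only torch is installed) of what the patched SpecialDataset now yields: each sample is a deterministic 4-token row [a, 2a, 3a, 4a] reused as both input_ids and labels, so training becomes a quick sanity check that the model can fit a trivial pattern.

# Standalone sketch of the patched SpecialDataset, for illustration only.
import torch
from torch.utils.data import Dataset


class SpecialDataset(Dataset):
    def __init__(self, size=65536):
        self.size = size
        a = torch.randint(0, 1024, [size])
        # shape [size, 4]: each row is (a, 2a, 3a, 4a)
        self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = self.data[idx]
        # input_ids and labels are the same deterministic sequence
        return {
            "input_ids": data,
            "labels": data,
            "token_type_ids": torch.zeros(data.shape),
        }


if __name__ == "__main__":
    sample = SpecialDataset(size=8)[0]
    print(sample["input_ids"])  # e.g. tensor([ 17,  34,  51,  68])
    assert torch.equal(sample["input_ids"], sample["labels"])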