From d10e7a83962f625208e9f5943f3cfe3a42f2af06 Mon Sep 17 00:00:00 2001
From: Colin
Date: Mon, 25 Mar 2024 19:53:11 +0800
Subject: [PATCH] Refine train.py for train.

---
 wit/train.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/wit/train.py b/wit/train.py
index b77f96e..3109c33 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -17,7 +17,7 @@ pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
 precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
-train_batch_size = 32
+train_batch_size = 16
 val_batch_size = 32
 num_proc = 8
 max_epochs = 1000
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

-    level_ratio = 4
+    level_ratio = 6
     start = vocab_size * level_ratio * level_ratio
     end = start * level_ratio
     size = end * level_ratio
@@ -47,15 +47,15 @@ if __name__ == "__main__":
     train_dataset, val_dataset = raw_dataset.Split(0.95)
     train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
     val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
-    it = iter(train_dataloader)
-    print("data samples:")
-    for i in range(10):
-        print(next(it)["input_ids"].numpy().tolist())
+    # it = iter(train_dataloader)
+    # print("data samples:")
+    # for i in range(10):
+    #     print(next(it)["input_ids"].numpy().tolist())

     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
-        accelerator="gpu",
-        # devices=[0],
+        accelerator="cuda",
+        devices=[0, 1],
         precision=precision,
         logger=TBLogger("./", default_hp_metric=False),
         strategy=strategy,
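
For reference, a minimal sketch of the Trainer setup that wit/train.py arrives at after this patch (two CUDA devices, 32-bit true precision). This is illustrative only: the real script builds LitModule and the dataloaders from project-specific code, "ddp" is assumed as the value of the script's `strategy` variable, and Lightning's TensorBoardLogger is assumed to be what `TBLogger` refers to.

    # Hypothetical, standalone sketch; not part of the patch itself.
    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger

    precision = "32-true"  # matches the setting at the top of wit/train.py

    torch.set_float32_matmul_precision("medium")
    trainer = pl.Trainer(
        accelerator="cuda",  # the patch switches from "gpu"; both select CUDA devices
        devices=[0, 1],      # the patch trains on two GPUs instead of the single default device
        precision=precision,
        logger=TensorBoardLogger("./", default_hp_metric=False),
        strategy="ddp",      # assumption: distributed data parallel across the two listed GPUs
    )
    # trainer.fit(lit_module, train_dataloader, val_dataloader)  # objects built earlier in the script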