From 9ef3e92b2318933c0b1aa6fc58faa57774555734 Mon Sep 17 00:00:00 2001
From: Colin
Date: Tue, 5 Mar 2024 22:09:28 +0800
Subject: [PATCH] Try model train.

---
 wit/lit_train.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/wit/lit_train.py b/wit/lit_train.py
index a958a97..03c46a2 100644
--- a/wit/lit_train.py
+++ b/wit/lit_train.py
@@ -17,6 +17,7 @@ from transformers import (
 from modelscope import snapshot_download
 from lit_module import LitModule
 from tokenization_qwen import QWenTokenizer
+from logger import TBLogger

 model_name = "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
@@ -27,20 +28,22 @@ train_batch_size = 256
 val_batch_size = 16
 num_proc = 8
 max_epochs = 1000
-strategy = "fsdp"
+strategy = "auto"
 resume_from_ckpt_path = None
 seed = 42
+vocab_size = 4096


 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=4096, size=65536):
+    def __init__(self, start=1, end=320, size=32768):
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
         b = torch.randint(start, end, [size])
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
-        self.data = torch.stack([a, b, c, d, ((a + b + c + d) / 4).long()]).permute(1, 0)
+        # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
+        self.data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)

     def __len__(self):
         return self.size
@@ -50,7 +53,8 @@ class SpecialDataset(Dataset):
         data = self.data[idx]
         output["input_ids"] = data
         output["labels"] = data.clone()
-        output["labels"][:4] = 0
+        # output["labels"][:2] = 0
+        # output["labels"][:2] = vocab_size
         output["token_type_ids"] = torch.zeros(data.shape)
         return output

@@ -67,10 +71,7 @@ if __name__ == "__main__":

     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

-    raw_dataset = SpecialDataset()
-    train_idx, val_idx = random_split(list(range(len(raw_dataset))), [0.95, 0.05])
-    train_dataset = Subset(raw_dataset, train_idx.indices)
-    val_dataset = Subset(raw_dataset, val_idx.indices)
+    train_dataset, val_dataset = random_split(SpecialDataset(), [0.95, 0.05])

     train_dataloader = DataLoader(
         train_dataset,
@@ -89,8 +90,13 @@
     )

     torch.set_float32_matmul_precision("medium")
-    precision = precision
-    lit_trainer = pl.Trainer(accelerator="gpu", precision=precision, strategy=strategy, max_epochs=max_epochs)
+    lit_trainer = pl.Trainer(
+        accelerator="gpu",
+        precision=precision,
+        logger=TBLogger("./", default_hp_metric=False),
+        strategy=strategy,
+        max_epochs=max_epochs,
+    )
     lit_trainer.fit(
         lit_module,
         train_dataloaders=train_dataloader,
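
A note on the SpecialDataset change, with a small sketch that is not part of the patch: each sample now carries the pattern [a, a + a, a + a, a + a], so the training task collapses to doubling the first token, and lowering end to 320 keeps every generated token well under the new vocab_size = 4096.

    # Sketch (reviewer illustration, not in the patch): what the new
    # SpecialDataset packs into each row of self.data.
    import torch

    start, end, size = 1, 320, 32768
    a = torch.randint(start, end, [size])
    data = torch.stack([a, a + a, a + a, a + a]).permute(1, 0)

    print(data.shape)                # torch.Size([32768, 4])
    print(data[0])                   # e.g. tensor([117, 234, 234, 234])
    assert data.max().item() < 4096  # all tokens stay inside vocab_size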
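
The random_split one-liner is behavior-preserving relative to the removed Subset code: torch.utils.data.random_split already returns Subset views over the shared dataset, and fractional lengths such as [0.95, 0.05] are accepted from PyTorch 1.13 onward. A minimal check:

    # Minimal check that random_split yields Subsets directly.
    from torch.utils.data import Dataset, Subset, random_split

    class Tiny(Dataset):
        def __len__(self):
            return 1000

        def __getitem__(self, idx):
            return idx

    train_ds, val_ds = random_split(Tiny(), [0.95, 0.05])
    assert isinstance(train_ds, Subset) and isinstance(val_ds, Subset)
    assert len(train_ds) == 950 and len(val_ds) == 50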
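
TBLogger comes from the local logger module, which this patch does not show; assuming it subclasses Lightning's TensorBoardLogger, the call signature matches, and default_hp_metric=False stops the logger from writing the placeholder hp_metric value at startup. A hedged equivalent with the stock logger, filling in the script's config values:

    # Sketch using the stock Lightning logger; TBLogger's real definition
    # lives elsewhere in the repo and may add behavior on top of this.
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger

    trainer = pl.Trainer(
        accelerator="gpu",
        precision="bf16-mixed",  # assumption: the script's `precision` value
        logger=TensorBoardLogger("./", default_hp_metric=False),
        strategy="auto",         # replaces "fsdp"; Lightning picks a strategy
        max_epochs=1000,
    )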