From 9feaafcb7af9a9d3778197f4b0070fba0edf1b86 Mon Sep 17 00:00:00 2001
From: Colin
Date: Fri, 15 Mar 2024 11:16:13 +0800
Subject: [PATCH] Apply meaning dataset to training.

Build the training data from MeaningDataset instead of SpecialDataset and
keep only sequences longer than one token. (A standalone sketch of the
refill loop and of the size arithmetic follows the diff.)
---
 wit/meaning_dataset.py | 33 ++++++++++++++++++++++-----------
 wit/train.py           | 34 ++++++++++++++++++----------------
 2 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py
index b7aeb4a..d905104 100644
--- a/wit/meaning_dataset.py
+++ b/wit/meaning_dataset.py
@@ -31,17 +31,17 @@ class MeaningMap:  # 16777216 1048576 8192
             print("Disk cache miss, build new one.")
 
             mm = np.empty((size, max_subitem), dtype=np.int32)
-            total_level = int(math.log(size / vocab_size, max_subitem))
+            # total_level = int(math.log(size / vocab_size, max_subitem))
 
-            start = [0]
-            end = [vocab_size]
-            shift = vocab_size
-            for i in range(total_level):
-                shift = end[-1]
-                start.append(end[-1])
-                end.append(shift * self.max_subitem)
-            start.append(end[-1])
-            end.append(size)
+            # start = [0]
+            # end = [vocab_size]
+            # shift = vocab_size
+            # for i in range(total_level):
+            #     shift = end[-1]
+            #     start.append(end[-1])
+            #     end.append(shift * self.max_subitem)
+            # start.append(end[-1])
+            # end.append(size)
 
             index = np.arange(0, size)
             mm = np.random.random((size, max_subitem))
@@ -108,7 +108,18 @@ class MeaningDataset(Dataset):
         self.data = []
         meanings = np.random.randint(start, end, size=(size))
         for m in meanings:
-            self.data.append(self.mm.GetSequence(m))
+            sq = self.mm.GetSequence(m)
+            if len(sq) > 1:
+                self.data.append(sq)
+        left = size - len(self.data)
+        while True:
+            if left <= 0:
+                break
+            index = np.random.randint(start, end)
+            sq = self.mm.GetSequence(index)
+            if len(sq) > 1:
+                self.data.append(sq)
+                left = left - 1
 
     def __len__(self):
         return self.size
diff --git a/wit/train.py b/wit/train.py
index b54ccd9..9d94236 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -8,10 +8,6 @@ import pytorch_lightning as pl
 import torch
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
 
-from transformers import (
-    DefaultDataCollator,
-    set_seed,
-)
 from lit_module import LitModule
 from tokenization_qwen import QWenTokenizer
 from logger import TBLogger
@@ -24,7 +20,7 @@ pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
 precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
-train_batch_size = 256
+train_batch_size = 1
 val_batch_size = 1
 num_proc = 8
 max_epochs = 1000
@@ -35,23 +31,30 @@ vocab_size = 256
 
 if __name__ == "__main__":
-
-    set_seed(seed)
+    torch.manual_seed(seed)
 
     config = ModelConfig()
     config.vocab_size = vocab_size
-    config.hidden_size = 128  # 128 1024 2048 32
-    config.num_hidden_layers = 6  # 6 12 24 3
-    config.num_attention_heads = 8  # 8 8 16
+    config.hidden_size = 1024  # 128 1024 2048 32
+    config.num_hidden_layers = 12  # 6 12 24 3
+    config.num_attention_heads = 16  # 8 8 16
 
     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
-    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
 
-    raw_dataset = SpecialDataset()
-    # raw_dataset = MeaningDataset(start=65536, end=262133, size=32768, max_subitem=4, vocab_size=vocab_size)
-    train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
+    # raw_dataset = SpecialDataset()
+    level_scale = 4
+    start = vocab_size * level_scale * level_scale
+    raw_dataset = MeaningDataset(
+        start=start,
+        end=start * level_scale,
+        size=start * level_scale * level_scale,
+        max_subitem=level_scale,
+        vocab_size=vocab_size,
+    )
+
+    train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
     it = iter(train_dataset)
     print("data samples:")
     for i in range(10):
@@ -61,7 +64,6 @@ if __name__ == "__main__":
         train_dataset,
         batch_size=train_batch_size,
         num_workers=num_proc,
-        collate_fn=DefaultDataCollator(),
         persistent_workers=True,
         shuffle=True,
     )
@@ -69,13 +71,13 @@ if __name__ == "__main__":
         val_dataset,
         batch_size=val_batch_size,
         num_workers=num_proc,
-        collate_fn=DefaultDataCollator(),
         persistent_workers=True,
     )
 
     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
         accelerator="gpu",
+        # devices=[0],
         precision=precision,
         logger=TBLogger("./", default_hp_metric=False),
         strategy=strategy,
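
Note on the MeaningDataset change above: the new loop keeps only sequences
longer than one token, then keeps sampling fresh meaning ids until the
dataset holds exactly `size` usable entries again. A minimal standalone
sketch of the same refill pattern, with a hypothetical get_sequence()
stand-in for self.mm.GetSequence():

    import numpy as np

    def get_sequence(m):
        # Hypothetical stand-in for MeaningMap.GetSequence: returns a
        # variable-length list of token ids for meaning id m.
        rng = np.random.default_rng(m)
        return list(rng.integers(0, 256, size=rng.integers(1, 8)))

    def build_data(start, end, size):
        data = []
        # First pass: one draw per requested sample, keep only length > 1.
        for m in np.random.randint(start, end, size=size):
            sq = get_sequence(m)
            if len(sq) > 1:
                data.append(sq)
        # Refill: keep drawing until exactly `size` usable sequences exist.
        left = size - len(data)
        while left > 0:
            sq = get_sequence(np.random.randint(start, end))
            if len(sq) > 1:
                data.append(sq)
                left -= 1
        return data

    print(len(build_data(4096, 16384, 100)))  # -> 100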
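
Note on the dataset range in train.py above: with vocab_size = 256 and
level_scale = 4, the MeaningDataset arguments work out to

    start = 256 * 4 * 4    # 4096
    end   = start * 4      # 16384
    size  = start * 4 * 4  # 65536

i.e. 65536 meaning ids are drawn from the interval [4096, 16384), replacing
the previously commented-out configuration (start=65536, end=262133,
size=32768, max_subitem=4).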
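
Note on the DataLoader change above: with DefaultDataCollator dropped from
both loaders, DataLoader falls back to torch.utils.data.default_collate,
and since train_batch_size = 1 each training batch holds a single sample,
so variable sequence lengths cannot clash during stacking. A quick check of
the fallback, assuming for illustration that each sample is a dict of
equal-length 1-D tensors:

    import torch
    from torch.utils.data import default_collate

    samples = [{"input_ids": torch.tensor([1, 2, 3])},
               {"input_ids": torch.tensor([4, 5, 6])}]
    batch = default_collate(samples)
    print(batch["input_ids"].shape)  # torch.Size([2, 3])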