diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py
index d905104..7b8098c 100644
--- a/wit/meaning_dataset.py
+++ b/wit/meaning_dataset.py
@@ -7,6 +7,7 @@ from itertools import chain
 from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
+from torch.utils.data import BatchSampler
 
 
 class MeaningMap:  # 16777216 1048576 8192
@@ -26,103 +27,103 @@ class MeaningMap:  # 16777216 1048576 8192
             self.ms_data = np.load(file_data)
             self.ms_start = np.load(file_start)
             self.ms_len = np.load(file_len)
-            return None
+        else:
+            print("Disk cache miss, build new one.")
 
-        print("Disk cache miss, build new one.")
+            mm = np.empty((size, max_subitem), dtype=np.int32)
 
-        mm = np.empty((size, max_subitem), dtype=np.int32)
-        # total_level = int(math.log(size / vocab_size, max_subitem))
+            index = np.arange(0, size)
+            mm = np.random.random((size, max_subitem))
 
-        # start = [0]
-        # end = [vocab_size]
-        # shift = vocab_size
-        # for i in range(total_level):
-        #     shift = end[-1]
-        #     start.append(end[-1])
-        #     end.append(shift * self.max_subitem)
-        # start.append(end[-1])
-        # end.append(size)
+            mask_zero = mm.copy()
+            mask_zero[:, 0] = 0.0
+            mask_zero.sort(axis=1)
+            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
+            mask_zero = mask_zero > thre
 
-        index = np.arange(0, size)
-        mm = np.random.random((size, max_subitem))
+            item_sum = mm.sum(axis=1)
+            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
+            mm = mm * scale
+            mm[mask_zero] = 0
 
-        mask_zero = mm.copy()
-        mask_zero[:, 0] = 0.0
-        mask_zero.sort(axis=1)
-        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
-        mask_zero = mask_zero > thre
+            mm[:vocab_size, 0] = np.arange(0, vocab_size)
+            mm[:vocab_size, 1:] = 0
+            mm = mm.astype(np.int32)
 
-        item_sum = mm.sum(axis=1)
-        scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
-        mm = mm * scale
-        mm[mask_zero] = 0
+            ms = []  # meaning sequence
+            ms_start = []  # meaning sequence start
+            ms_len = []  # meaning sequence length
+            index = 0
+            for i in range(self.vocab_size):
+                ms.append([i])
+                ms_start.append(index)
+                ms_len.append(1)
+                index = index + 1
 
-        mm[:vocab_size, 0] = np.arange(0, vocab_size)
-        mm[:vocab_size, 1:] = 0
-        mm = mm.astype(np.int32)
+            for i in range(self.vocab_size, size):
+                m = mm[i]
+                m = m[m > 0]
+                ma = []
+                for newm in m.tolist():
+                    ma = ma + ms[newm]
+                ms.append(ma)
+                ms_start.append(index)
+                ms_len.append(len(ma))
+                index = index + len(ma)
 
-        ms = []  # meaning sequence
-        ms_start = []  # meaning sequence start
-        ms_len = []  # meaning sequence length
-        index = 0
-        for i in range(self.vocab_size):
-            ms.append([i])
-            ms_start.append(index)
-            ms_len.append(1)
-            index = index + 1
+            ms_data = list(chain(*ms))
+            np.save(file_data, np.array(ms_data).astype(np.int32))
+            np.save(file_start, np.array(ms_start).astype(np.int32))
+            np.save(file_len, np.array(ms_len).astype(np.int32))
 
-        for i in range(self.vocab_size, size):
-            m = mm[i]
-            m = m[m > 0]
-            ma = []
-            for newm in m.tolist():
-                ma = ma + ms[newm]
-            ms.append(ma)
-            ms_start.append(index)
-            ms_len.append(len(ma))
-            index = index + len(ma)
-
-        ms_data = list(chain(*ms))
-        np.save(file_data, np.array(ms_data).astype(np.int32))
-        np.save(file_start, np.array(ms_start).astype(np.int32))
-        np.save(file_len, np.array(ms_len).astype(np.int32))
-
-        self.ms_data = ms_data
-        self.ms_start = ms_start
-        self.ms_len = ms_len
-        print("Disk cache build end.")
+            self.ms_data = ms_data
+            self.ms_start = ms_start
+            self.ms_len = ms_len
+            print("Disk cache build end.")
 
     def GetSequence(self, meaning):
         start = self.ms_start[meaning]
         len = self.ms_len[meaning]
         return self.ms_data[start : start + len]
 
+    def MaxLength(self):
+        return max(self.ms_len)
+
 
 class MeaningDataset(Dataset):
 
-    def __init__(self, start=131072, end=1048576, size=32768, vocab_size=4096, max_subitem=10, seed=42):
-        self.seed = seed
+    def __init__(
+        self,
+        start=131072,
+        end=1048576,
+        size=32768,
+        vocab_size=4096,
+        max_subitem=10,
+        min_seq_len=2,
+        seed=42,
+        data=None,
+        length=None,
+    ):
+        if data != None and length != None:
+            self.data = data
+            self.length = length
+            return
         np.random.seed(seed)
-        self.size = size
-        self.mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
+        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
         self.data = []
+        self.length = []
         meanings = np.random.randint(start, end, size=(size))
         for m in meanings:
-            sq = self.mm.GetSequence(m)
-            if len(sq) > 1:
+            sq = mm.GetSequence(m)
+            if len(sq) >= min_seq_len:
                 self.data.append(sq)
-        left = size - len(self.data)
-        while True:
-            if left <= 0:
-                break
-            index = np.random.randint(start, end)
-            sq = self.mm.GetSequence(index)
-            if len(sq) > 1:
-                self.data.append(sq)
-                left = left - 1
+                self.length.append(len(sq))
 
     def __len__(self):
-        return self.size
+        return len(self.data)
+
+    def len(self):
+        return len(self.data)
 
     def __getitem__(self, idx):
         output = {}
@@ -132,11 +133,93 @@ class MeaningDataset(Dataset):
         output["token_type_ids"] = torch.zeros(data.shape)
         return output
 
+    def GetBatch(self, index_list):
+        data = []
+        for i in index_list:
+            data.append(self.data[i])
+        output = {}
+        data = torch.tensor(data).long()
+        output["input_ids"] = data
+        output["labels"] = data.clone()
+        output["token_type_ids"] = torch.zeros(data.shape)
+        return output
+
+    def Split(self, ratio):
+        l = len(self.data)
+        middle = int(l * ratio)
+        d_shuffle = self.data.copy()
+        l_shuffle = self.length.copy()
+        md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle])
+        md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:])
+        return md1, md2
+
+
+class BatchGroupMeaningDataloader(Dataset):
+
+    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+        length = dataset.length
+        unique, counts = np.unique(length, return_counts=True)
+        gl = {}
+        for u in unique:
+            gl[u] = np.where(length == u)[0]
+
+        lens = list(gl.keys())
+        gs = {}
+        if shuffle:
+            for k in gl.keys():
+                sl = gl[k].copy()
+                np.random.shuffle(sl)
+                gs[k] = sl
+        else:
+            for k in gl.keys():
+                sl = gl[k].copy()
+                gs[k] = sl
+
+        index = np.zeros((0, batch_size), dtype=np.int64)
+        for l in lens:
+            batch = len(gs[l]) // batch_size
+            new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
+            index = np.concatenate((index, new), axis=0)
+        if shuffle:
+            index_shuffle = np.arange(0, index.shape[0])
+            np.random.shuffle(index_shuffle)
+            index = index[index_shuffle]
+        self.index = index
+
+    def __len__(self):
+        return len(self.index)
+
+    def __getitem__(self, idx):
+        # print("get idx" + str(idx))
+        return self.dataset.GetBatch(self.index[idx])
+
 
 if __name__ == "__main__":
 
-    md = MeaningDataset(4096, 4100, size=32768)
-    it = iter(md)
+    md = MeaningDataset(4096, 8100, size=1024)
+    train, val = md.Split(0.95)
+
+    dl = BatchGroupMeaningDataloader(train, 2)
+    it = iter(dl)
+    ne1 = next(it)
+    ne2 = next(it)
+    ne3 = next(it)
+
+    dl = DataLoader(
+        train,
+        num_workers=1,
+        persistent_workers=True,
+        shuffle=False,
+    )
+    it = iter(dl)
+    ne1 = next(it)
+    ne2 = next(it)
+    ne3 = next(it)
+
     for i in range(10):
         daf = next(it)["input_ids"].numpy().tolist()
diff --git a/wit/train.py b/wit/train.py
index 9d94236..b77f96e 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -3,25 +3,22 @@ from functools import partial
 from itertools import chain
 from typing import Dict, Tuple
 
-import datasets
 import pytorch_lightning as pl
 import torch
-from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
 
 from lit_module import LitModule
 from tokenization_qwen import QWenTokenizer
 from logger import TBLogger
 
-from special_dataset import SpecialDataset
-from meaning_dataset import MeaningDataset
+from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
 from wit.configuration import ModelConfig
 
 pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
 precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
-train_batch_size = 1
-val_batch_size = 1
+train_batch_size = 32
+val_batch_size = 32
 num_proc = 8
 max_epochs = 1000
 strategy = "auto"
@@ -42,38 +39,19 @@ if __name__ == "__main__":
     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
 
-    # raw_dataset = SpecialDataset()
-
-    level_scale = 4
-    start = vocab_size * level_scale * level_scale
-    raw_dataset = MeaningDataset(
-        start=start,
-        end=start * level_scale,
-        size=start * level_scale * level_scale,
-        max_subitem=level_scale,
-        vocab_size=vocab_size,
-    )
-
-    train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
-    it = iter(train_dataset)
+    level_ratio = 4
+    start = vocab_size * level_ratio * level_ratio
+    end = start * level_ratio
+    size = end * level_ratio
+    raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
+    train_dataset, val_dataset = raw_dataset.Split(0.95)
+    train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
+    val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
+    it = iter(train_dataloader)
     print("data samples:")
     for i in range(10):
         print(next(it)["input_ids"].numpy().tolist())
 
-    train_dataloader = DataLoader(
-        train_dataset,
-        batch_size=train_batch_size,
-        num_workers=num_proc,
-        persistent_workers=True,
-        shuffle=True,
-    )
-    val_dataloader = DataLoader(
-        val_dataset,
-        batch_size=val_batch_size,
-        num_workers=num_proc,
-        persistent_workers=True,
-    )
-
     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
         accelerator="gpu",
diff --git a/wit/train_special.py b/wit/train_special.py
new file mode 100644
index 0000000..65d3775
--- /dev/null
+++ b/wit/train_special.py
@@ -0,0 +1,79 @@
+import argparse
+from functools import partial
+from itertools import chain
+from typing import Dict, Tuple
+
+import datasets
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
+
+from lit_module import LitModule
+from tokenization_qwen import QWenTokenizer
+from logger import TBLogger
+
+from special_dataset import SpecialDataset
+from meaning_dataset import MeaningDataset
+from wit.configuration import ModelConfig
+
+pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
+learning_rate = 0.0001
+use_tril_attention_mask = None
+precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
+train_batch_size = 128
+val_batch_size = 128
+num_proc = 8
+max_epochs = 1000
+strategy = "auto"
+resume_from_ckpt_path = None
+seed = 42
+vocab_size = 256
+
+
+if __name__ == "__main__":
+    torch.manual_seed(seed)
+
+    config = ModelConfig()
+    config.vocab_size = vocab_size
+    config.hidden_size = 128  # 128 1024 2048 32
+    config.num_hidden_layers = 3  # 6 12 24 3
+    config.num_attention_heads = 8  # 8 8 16
+
+    lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
+    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+
+    raw_dataset = SpecialDataset()
+    train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
+    it = iter(train_dataset)
+    print("data samples:")
+    for i in range(10):
+        print(next(it)["input_ids"].numpy().tolist())
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        num_workers=num_proc,
+        persistent_workers=True,
+        shuffle=True,
+    )
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=val_batch_size,
+        num_workers=num_proc,
+        persistent_workers=True,
+    )
+
+    torch.set_float32_matmul_precision("medium")
+    lit_trainer = pl.Trainer(
+        accelerator="gpu",
+        precision=precision,
+        logger=TBLogger("./", default_hp_metric=False),
+        strategy=strategy,
+        max_epochs=max_epochs,
+    )
+    lit_trainer.fit(
+        lit_module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+        ckpt_path=resume_from_ckpt_path,
+    )
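
Editor's note, not part of the patch: the new BatchGroupMeaningDataloader buckets sample indices by sequence length and cuts each bucket into fixed-size batches, so every batch stacks into a rectangular tensor without any padding. The sketch below re-implements that idea in isolation to make the mechanism explicit; the helper name group_by_length and the toy data are illustrative assumptions, and only numpy and torch are assumed to be available.

import numpy as np
import torch

def group_by_length(lengths, batch_size, shuffle=True, seed=0):
    # Bucket indices by sequence length, then cut each bucket into full
    # batches of identical-length samples (the remainder is dropped,
    # mirroring drop_last behaviour).
    rng = np.random.default_rng(seed)
    lengths = np.asarray(lengths)
    batches = []
    for value in np.unique(lengths):
        idx = np.where(lengths == value)[0]
        if shuffle:
            rng.shuffle(idx)
        full = len(idx) // batch_size
        batches.extend(idx[: full * batch_size].reshape(full, batch_size))
    if shuffle:
        rng.shuffle(batches)
    return batches

# Toy data: variable-length sequences standing in for meaning sequences.
data = [[1, 2], [3, 4], [5, 6, 7], [8, 9, 10], [11, 12]]
for batch_idx in group_by_length([len(d) for d in data], batch_size=2):
    batch = torch.tensor([data[i] for i in batch_idx]).long()  # rectangular, no padding needed
    print(batch.shape, batch_idx)

Because each __getitem__ of BatchGroupMeaningDataloader already returns a complete batch dict, train.py can pass it to the Lightning trainer directly in place of a torch DataLoader.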