diff --git a/.gitignore b/.gitignore
index 5968d55..2a40627 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 __pycache__
 .vscode
 *.txt
+*.npy
 temp
 
 # lightning_logs
diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py
new file mode 100644
index 0000000..ec4071c
--- /dev/null
+++ b/wit/meaning_dataset.py
@@ -0,0 +1,132 @@
+import os
+import datasets
+import torch
+import math
+import random
+from itertools import chain
+from typing import Dict, Tuple
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
+import numpy as np
+
+
+class MeaningMap:  # 16777216  1048576  8192
+
+    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
+        self.size = size
+        self.vocab_size = vocab_size
+        self.max_subitem = max_subitem
+
+        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
+        file_start = file + "_start" + ".npy"
+        file_len = file + "_len" + ".npy"
+        file_data = file + "_data" + ".npy"
+
+        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
+            print("Load from disk cache: " + file)
+            self.ms_data = np.load(file_data)
+            self.ms_start = np.load(file_start)
+            self.ms_len = np.load(file_len)
+            return None
+
+        print("Disk cache miss, build new one.")
+
+        mm = np.empty((size, max_subitem), dtype=np.int32)
+        total_level = int(math.log(size / vocab_size, max_subitem))
+
+        start = [0]
+        end = [vocab_size]
+        shift = vocab_size
+        for i in range(total_level):
+            shift = end[-1]
+            start.append(end[-1])
+            end.append(shift * self.max_subitem)
+        start.append(end[-1])
+        end.append(size)
+
+        index = np.arange(0, size)
+        mm = np.random.random((size, max_subitem))
+
+        mask_zero = mm.copy()
+        mask_zero[:, 0] = 0.0
+        mask_zero.sort(axis=1)
+        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
+        mask_zero = mask_zero > thre
+
+        item_sum = mm.sum(axis=1)
+        scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
+        mm = mm * scale
+        mm[mask_zero] = 0
+
+        mm[:vocab_size, 0] = np.arange(0, vocab_size)
+        mm[:vocab_size, 1:] = 0
+        mm = mm.astype(np.int32)
+
+        ms = []  # meaning sequence
+        ms_start = []  # meaning sequence start
+        ms_len = []  # meaning sequence length
+        index = 0
+        for i in range(self.vocab_size):
+            ms.append([i])
+            ms_start.append(index)
+            ms_len.append(1)
+            index = index + 1
+
+        for i in range(self.vocab_size, size):
+            m = mm[i]
+            m = m[m > 0]
+            ma = []
+            for newm in m.tolist():
+                ma = ma + ms[newm]
+            ms.append(ma)
+            ms_start.append(index)
+            ms_len.append(len(ma))
+            index = index + len(ma)
+
+        ms_data = list(chain(*ms))
+        np.save(file_data, np.array(ms_data).astype(np.int32))
+        np.save(file_start, np.array(ms_start).astype(np.int32))
+        np.save(file_len, np.array(ms_len).astype(np.int32))
+
+        self.ms_data = ms_data
+        self.ms_start = ms_start
+        self.ms_len = ms_len
+        print("Disk cache build end.")
+
+    def GetSequence(self, meaning):
+        start = self.ms_start[meaning]
+        len = self.ms_len[meaning]
+        return self.ms_data[start : start + len]
+
+
+class MeaningDataset(Dataset):
+
+    def __init__(self, start=131072, end=1048576, size=32768, vocab_size=4096, max_subitem=10, seed=42):
+        self.seed = seed
+        np.random.seed(seed)
+        self.size = size
+        self.mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
+        self.data = []
+        meanings = np.random.randint(start, end, size=(size))
+        for m in meanings:
+            self.data.append(self.mm.GetSequence(m))
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        output = {}
+        data = torch.tensor(self.data[idx])
+        output["input_ids"] = data
+        output["labels"] = data.clone()
+        output["token_type_ids"] = torch.zeros(data.shape)
+        return output
+
+
+if __name__ == "__main__":
+
+    md = MeaningDataset(4096, 4100, size=32768)
+    it = iter(md)
+    for i in range(10):
+        daf = next(it)["input_ids"].numpy().tolist()
+
+        print(daf)
diff --git a/wit/special_dataset.py b/wit/special_dataset.py
new file mode 100644
index 0000000..c3ab154
--- /dev/null
+++ b/wit/special_dataset.py
@@ -0,0 +1,49 @@
+import argparse
+from functools import partial
+from itertools import chain
+from typing import Dict, Tuple
+
+import datasets
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
+
+
+class SpecialDataset(Dataset):
+    def __init__(self, start=1, end=320, size=32768):  # 1048576  32768
+        self.size = size
+        self.features = []
+        a = torch.randint(start, end, [size])
+        b = torch.randint(start, end, [size])
+        c = torch.randint(start, end, [size])
+        d = torch.randint(start, end, [size])
+        z = torch.zeros([size]).long()
+        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
+        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
+        # self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
+        self.data = torch.stack([a, b, a]).permute(1, 0).long()
+        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
+
+        # input   a b c
+        # output  b c x
+        # label   a b c
+
+        # a = torch.randint(start, end, [size])
+        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
+        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
+        # Only one mapping rule can exist, and the first value cannot be used for training
+        # A transition that is too steep is hard to fit
+        # A search space that is too large is hard to fit
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        output = {}
+        data = self.data[idx]
+        output["input_ids"] = data
+        output["labels"] = data.clone()
+        # output["labels"][:2] = 0
+        # output["labels"][:2] = vocab_size
+        output["token_type_ids"] = torch.zeros(data.shape)
+        return output
diff --git a/wit/train.py b/wit/train.py
index fc93eac..272626b 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -19,12 +19,15 @@ from lit_module import LitModule
 from tokenization_qwen import QWenTokenizer
 from logger import TBLogger
 
+from special_dataset import SpecialDataset
+from meaning_dataset import MeaningDataset
+
 model_name = "qwen/Qwen-1_8B-Chat"
 learning_rate = 0.0001
 use_tril_attention_mask = None
 precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 tokenizer_name_or_path = None
-train_batch_size = 256
+train_batch_size = 16
 val_batch_size = 16
 num_proc = 8
 max_epochs = 1000
@@ -34,55 +37,22 @@ seed = 42
 vocab_size = 4096
 
 
-class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=16, size=32768):  # 1048576  32768
-        self.size = size
-        self.features = []
-        a = torch.randint(start, end, [size])
-        b = torch.randint(start, end, [size])
-        c = torch.randint(start, end, [size])
-        d = torch.randint(start, end, [size])
-        z = torch.zeros([size]).long()
-        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
-        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
-        # self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
-        self.data = torch.stack([a, b, a]).permute(1, 0).long()
-        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
-
-        # a = torch.randint(start, end, [size])
-        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
-        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
-        # Only one mapping rule can exist, and the first value cannot be used for training
-        # A transition that is too steep is hard to fit
-        # A search space that is too large is hard to fit
-
-    def __len__(self):
-        return self.size
-
-    def __getitem__(self, idx):
-        output = {}
-        data = self.data[idx]
-        output["input_ids"] = data
-        output["labels"] = data.clone()
-        # output["labels"][:2] = 0
-        # output["labels"][:2] = vocab_size
-        output["token_type_ids"] = torch.zeros(data.shape)
-        return output
-
-
 if __name__ == "__main__":
     if tokenizer_name_or_path is None:
         tokenizer_name_or_path = model_name
 
     set_seed(seed)
-    # lightning module
 
     model_dir = snapshot_download(model_name)
     lit_module = LitModule(model_dir, learning_rate, use_tril_attention_mask)
 
     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
 
-    train_dataset, val_dataset = random_split(SpecialDataset(), [0.95, 0.05])
+    # raw_dataset = SpecialDataset()
+    raw_dataset = MeaningDataset(start=131072, end=1048576, size=32768)
+    train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
+
+    # daf = next(iter(train_dataset))["input_ids"].numpy().tolist()
 
     train_dataloader = DataLoader(
         train_dataset,
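
Note: the DataLoader call in train.py is cut off above, and MeaningDataset returns variable-length sequences, so PyTorch's default collate cannot stack a batch of them directly. The following is a minimal sketch of one way such samples could be batched, assuming a hypothetical pad_collate helper and a pad id of 0; neither is defined in this patch.

# Illustration only -- not part of this diff. `pad_collate` and pad id 0 are assumptions.
import torch
from torch.utils.data import DataLoader, random_split

from meaning_dataset import MeaningDataset


def pad_collate(batch, pad_id=0):
    # Pad every 1-D field of each sample to the longest sequence in the batch, then stack.
    max_len = max(item["input_ids"].size(0) for item in batch)

    def pad(t, value):
        return torch.nn.functional.pad(t, (0, max_len - t.size(0)), value=value)

    return {
        "input_ids": torch.stack([pad(item["input_ids"], pad_id) for item in batch]),
        "labels": torch.stack([pad(item["labels"], pad_id) for item in batch]),
        "token_type_ids": torch.stack([pad(item["token_type_ids"].long(), 0) for item in batch]),
    }


if __name__ == "__main__":
    raw_dataset = MeaningDataset(start=131072, end=1048576, size=32768)
    train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
    train_dataloader = DataLoader(train_dataset, batch_size=16, collate_fn=pad_collate)
    batch = next(iter(train_dataloader))
    print(batch["input_ids"].shape)  # torch.Size([16, longest_sequence_in_batch])

In real training the padded label positions would normally be masked out of the loss (for example set to -100 in Hugging Face style models); that detail is omitted from this sketch.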