From 33d1e226551029381f3771b8ea9e4dd132424f11 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 7 Apr 2024 00:25:21 +0800 Subject: [PATCH] Refine meaning dataset. --- test/tree.py | 67 +++++++++ wit/meaning_dataset.md | 38 +++++ wit/meaning_dataset.py | 314 ++++++++++++++++++++++++++++------------- wit/train.py | 12 +- 4 files changed, 325 insertions(+), 106 deletions(-) create mode 100644 test/tree.py create mode 100644 wit/meaning_dataset.md diff --git a/test/tree.py b/test/tree.py new file mode 100644 index 0000000..7878929 --- /dev/null +++ b/test/tree.py @@ -0,0 +1,67 @@ +import numpy as np + +a = np.array([0, 1, 32 + 1, (32 + 1) * 16, 4, 5, 6, 7, 8, 8]).astype(np.uint32) +b = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8]).astype(np.uint32) + + +d = np.ones(a.shape, dtype=np.uint32) +d = (d * 0xFFFFFFFF) << (b * 4) + +c = a.astype(np.uint32) + +cc = ( + ((c & 0xF) << 28) + + ((c & 0xF0) << 20) + + ((c & 0xF00) << 12) + + ((c & 0xF000) << 4) + + ((c & 0xF0000) >> 4) + + ((c & 0xF00000) >> 12) + + ((c & 0xF000000) >> 20) + + ((c & 0xF0000000) >> 28) +) +cc = (cc >> ((8 - b) * 4)) + d + +print(cc[3] == 4294963218) + +b = np.ones((10)).astype(np.int32) + + +def get_tree_str_new(tree, prefix): + if isinstance(tree, dict): + base = "" + last_is_dict = None + for key, value in tree.items(): + new_prefix = (len(str(key)) + 2) * " " + prefix + dict_string = get_tree_str_new(value, new_prefix) + if dict_string: + base += "\n" + prefix + str(key) + ": " + dict_string + last_is_dict = True + else: + base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " " + last_is_dict = False + return base + return None + + +tree = { + 112377: { + 2944: {228: 228, 263: 263, 252: 252, 396: 396}, + 10024: { + 1424: {189: 189, 209: 209, 200: 200, 102: 102, 178: 178, 22: 22, 9: 9}, + 1053: 432, + 1350: {68: 68, 200: 200, 50: 50, 17: 17, 36: 36, 283: 283}, + 7: 7, + }, + 18196: 322, + 13373: { + 1420: {99: 99, 189: 189, 163: 163}, + 2109: {320: 320, 92: 92, 95: 95, 224: 224, 435: 435, 4: 4, 373: 373, 27: 27, 228: 228}, + 708: 708, + 2196: {27: 27, 157: 157, 87: 87, 231: 231}, + 401: 401, + }, + } +} + + +print(get_tree_str_new(tree, "")) diff --git a/wit/meaning_dataset.md b/wit/meaning_dataset.md new file mode 100644 index 0000000..9b54789 --- /dev/null +++ b/wit/meaning_dataset.md @@ -0,0 +1,38 @@ +# meaning dataset + +meaning数据集是一个模仿自然语言,以及抽象表达的数据集。 + +## 概念 + +1. token表示最终体现的基本数据表达,类似单词。vocab_size表示代表token的数量。 +2. meaning表示一种语义(符号),所有的meaning都由一个编号表达,编号越大表示语义越复杂 +3. 所有的meaning都可以由更低标号表达 +4. 从0到vocab_size的编号表示基本meaning,是不能被拆解的,也就是token +5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据 +6. level表示当前token相对于root meaning的距离 +7. idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的index,高位无用的位用1填充 +8. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构 +9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index +10. meaning_height +11. meaning_weight + +``` +vocab_size = 256 meaning = 115200 + + 115200 + / | \ + 10240 1100 12322 + / | \ / \ / | \ + 512 32 1201 245 233 3214 532 324 + / \ / \ / \ | / \ + 123 42 320 500 1231 23 324 93 176 + / \ / \ / \ / \ + 176 11 255 129 129 99 211 111 + +sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176 +level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3 +idx at 0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1 +idx at 1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2 +idx 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33 + +``` \ No newline at end of file diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index 6d4057a..ada1531 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -11,8 +11,7 @@ from torch.utils.data import BatchSampler class MeaningMap: - - def __init__(self, size=1048576, vocab_size=4096, max_subitem=10): + def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True): self.size = size self.vocab_size = vocab_size self.max_subitem = max_subitem @@ -20,99 +19,186 @@ class MeaningMap: path = "./data/" file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) file = path + file - file_map = file + "_map" + ".npy" - file_start = file + "_start" + ".npy" - file_len = file + "_len" + ".npy" - file_data = file + "_data" + ".npy" + file_slhwm = file + "_slhwm" + ".npy" + file_dli = file + "_dli" + ".npy" if not os.path.exists(path): os.mkdir(path) - if ( - os.path.exists(file_start) - and os.path.exists(file_len) - and os.path.exists(file_data) - and os.path.exists(file_map) - ): + if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache: print("Load from disk cache: " + file) - self.ms_map = np.load(file_map) - self.ms_data = np.load(file_data) - self.ms_start = np.load(file_start) - self.ms_len = np.load(file_len) + slhwm = np.load(file_slhwm) + dli = np.load(file_dli) + self.ms_map = slhwm[:, 4:] + self.ms_data = dli[:, 0] + self.ms_start = slhwm[:, 0] + self.ms_len = slhwm[:, 1] + self.ms_level = dli[:, 1] + self.ms_idx = dli[:, 2].astype(np.uint32) + self.ms_height = slhwm[:, 2] + self.ms_weight = slhwm[:, 3] print("Load end") else: print("Disk cache miss, build new one.") - mm = np.empty((size, max_subitem), dtype=np.int32) + map = np.empty((size, max_subitem), dtype=np.uint32) index = np.arange(0, size) - mm = np.random.random((size, max_subitem)) + map = np.random.random((size, max_subitem)) - mask_zero = mm.copy() + mask_zero = map.copy() mask_zero[:, 0] = 0.0 mask_zero.sort(axis=1) thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1) mask_zero = mask_zero > thre - item_sum = mm.sum(axis=1) + item_sum = map.sum(axis=1) scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1) - mm = mm * scale - mm[mask_zero] = 0 + map = map * scale - mm[:vocab_size, 0] = np.arange(0, vocab_size) - mm[:vocab_size, 1:] = 0 - mm = mm.astype(np.int32) + map[mask_zero] = 0 - ms = [] # meaning sequence + map[:vocab_size, 0] = np.arange(0, vocab_size) + map[:vocab_size, 1:] = 0 + map = map.astype(np.uint32) + + ms_data = [] # meaning sequence + ms_level = [] # meaning level, vocab's level is 0 + ms_idx = [] # meaning index of lowest level ms_start = [] # meaning sequence start ms_len = [] # meaning sequence length + ms_height = [] # meaning tree height + ms_weight = [] # meaning tree weight index = 0 for i in range(self.vocab_size): - ms.append([i]) + ms_data.append([i]) + ms_level.append([0]) + ms_idx.append([0]) ms_start.append(index) ms_len.append(1) index = index + 1 + ms_height.append(0) + ms_weight.append(1) for i in range(self.vocab_size, size): - m = mm[i] + m = map[i] m = m[m > 0] ma = [] - for newm in m.tolist(): - ma = ma + ms[newm] - ms.append(ma) + ml = [] + mi = [] + for i, newm in enumerate(m.tolist()): + ma = ma + ms_data[newm] + ml = ml + [x + 1 for x in ms_level[newm]] + mi = mi + ([0xFFFFFFF0 + i] if newm < self.vocab_size else [n * 16 + i for n in ms_idx[newm]]) + ms_data.append(ma) ms_start.append(index) ms_len.append(len(ma)) + ms_level.append(ml) + ms_idx.append(mi) index = index + len(ma) + ms_height.append(max([-1] + [ms_height[sub_m] for sub_m in m.tolist()]) + 1) + ms_weight.append(sum(ms_weight[sub_m] for sub_m in m.tolist())) - ms_data = list(chain(*ms)) - np.save(file_map, np.array(mm).astype(np.int32)) - np.save(file_data, np.array(ms_data).astype(np.int32)) - np.save(file_start, np.array(ms_start).astype(np.int32)) - np.save(file_len, np.array(ms_len).astype(np.int32)) + # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28] + # for idxmi, mi in enumerate(ms_idx): + # level = ms_level[idxmi] + # for idxnum, num in enumerate(mi): + # l = level[idxnum] + # elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]] + # num = (num >> (l * 4)) << (l * 4) + # num += sum(elem << (i * 4) for i, elem in enumerate(elements)) + # mi[idxnum] = num - self.ms_map = mm - self.ms_data = ms_data + ms_data = np.array(list(chain(*ms_data))).astype(np.int32) + ms_level = np.array(list(chain(*ms_level))).astype(np.int32) + ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32) + + d = np.ones(ms_idx.shape, dtype=np.uint32) + d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32) + ms_idx = ( + ((ms_idx & 0xF) << 28) + + ((ms_idx & 0xF0) << 20) + + ((ms_idx & 0xF00) << 12) + + ((ms_idx & 0xF000) << 4) + + ((ms_idx & 0xF0000) >> 4) + + ((ms_idx & 0xF00000) >> 12) + + ((ms_idx & 0xF000000) >> 20) + + ((ms_idx & 0xF0000000) >> 28) + ) + ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) + + ms_start = np.array(ms_start).astype(np.uint32) + ms_height = np.array(ms_height).astype(np.uint32) + ms_weight = np.array(ms_weight).astype(np.uint32) + ms_len = np.array(ms_len).astype(np.uint32) + ms_map = map.astype(np.uint32) + + slhwm = np.concatenate( + ( + ms_start.reshape((-1, 1)), + ms_len.reshape((-1, 1)), + ms_height.reshape((-1, 1)), + ms_weight.reshape((-1, 1)), + ms_map, + ), + axis=1, + ) + dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1) + + np.save(file_slhwm, slhwm) + np.save(file_dli, dli) + + self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)] + self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]] self.ms_start = ms_start self.ms_len = ms_len + self.ms_level = ms_level + self.ms_idx = ms_idx + self.ms_height = ms_height + self.ms_weight = ms_weight print("Disk cache build end.") - def get_sequence(self, meaning): + def get_sequence(self, meaning): # return sequence[meaning] start = self.ms_start[meaning] len = self.ms_len[meaning] - return self.ms_data[start : start + len] + return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len] - def get_mapping(self, meaning): - mapping = {} + def get_tree(self, meaning): # return meaning all sub items + tree = {} ms = self.ms_map[meaning] for m in ms[ms > 0].tolist(): - mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m - return mapping + tree[m] = self.get_tree(m) if m >= self.vocab_size else m + return tree def max_length(self): return max(self.ms_len) + def get_tree_str(tree, prefix): + if isinstance(tree, dict): + base = "" + last_is_dict = None + for key, value in tree.items(): + new_prefix = (len(str(key)) + 2) * " " + prefix + dict_string = MeaningMap.get_tree_str(value, new_prefix) + if dict_string: + base += "\n" + prefix + str(key) + ": " + dict_string + last_is_dict = True + else: + base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " " + last_is_dict = False + return base + return None + + def token_frequency(tree, freq): + if isinstance(tree, dict): + for key, value in tree.items(): + if key in freq: + freq[key] = freq[key] + 1 + else: + freq[key] = 1 + MeaningMap.token_frequency(value, freq) + class MeaningDataset(Dataset): - def __init__( self, start=131072, @@ -124,25 +210,34 @@ class MeaningDataset(Dataset): seed=42, data=None, length=None, - mapping=None, + tree=None, + level=None, + idx=None, + use_cache=True, ): - if data != None and length != None and mapping != None: + if data != None and length != None and tree != None and level != None and idx != None: self.data = data self.length = length - self.mapping = mapping + self.tree = tree + self.level = level + self.idx = idx return np.random.seed(seed) - mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem) # 1048576 - self.mapping = [] + map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache) + self.tree = [] self.data = [] + self.level = [] + self.idx = [] self.length = [] meanings = np.random.randint(start, end, size=(size)) for m in meanings: - sq = mm.get_sequence(m) - if len(sq) >= min_seq_len: - self.mapping.append({m: mm.get_mapping(m)}) - self.data.append(sq) - self.length.append(len(sq)) + d, l, i = map.get_sequence(m) + if len(d) >= min_seq_len: + self.tree.append({m: map.get_tree(m)}) + self.data.append(d) + self.level.append(l) + self.idx.append(i) + self.length.append(len(d)) unique, counts = np.unique(self.length, return_counts=True) print("----------------------------------------------------------------") @@ -164,50 +259,34 @@ class MeaningDataset(Dataset): output["input_ids"] = data output["labels"] = data.clone() output["token_type_ids"] = torch.zeros(data.shape) + output["tree"] = self.tree[idx] + output["level"] = self.level[idx] + output["idx"] = self.idx[idx] return output - def get_batch(self, index_list): # must equal sequence length - data = [self.data[i] for i in index_list] + def get_batch(self, idx_list): # must equal sequence length + data = [self.data[i] for i in idx_list] output = {} data = torch.tensor(np.stack(data, axis=0)).long() output["input_ids"] = data output["labels"] = data.clone() output["token_type_ids"] = torch.zeros(data.shape) + output["tree"] = [self.tree[i] for i in idx_list] + output["level"] = [self.level[i] for i in idx_list] + output["idx"] = [self.idx[i] for i in idx_list] return output - def get_token_batch(self, index_list): # must equal sequence length - return [self.data[i] for i in index_list] + def get_token(self, idx): # must equal sequence length + return self.data[idx] - def print_token_batch(self, index_list): # must equal sequence length - data = [self.data[i] for i in index_list] - output = {} - data = torch.tensor(np.stack(data, axis=0)).long() - output["input_ids"] = data - output["labels"] = data.clone() - output["token_type_ids"] = torch.zeros(data.shape) - return output + def get_tree(self, idx): + return self.tree[idx] - def get_mapping_batch(self, index_list): - return [self.mapping[i] for i in index_list] - - def __get_mapping_str__(map, prefix): - if isinstance(map, dict): - base = "" - for key, value in map.items(): - base += prefix + str(key) + "\n" - base += MeaningDataset.__get_mapping_str__(value, prefix + " ") - return base - else: - return "" - - def print_mapping_batch(self, index_list): - tokens = self.get_token_batch(index_list) - map = self.get_mapping_batch(index_list) - s = "--------------------------------------------------------\n" - for i, m in enumerate(map): - s += str(tokens[i]) + "\n" - s += MeaningDataset.__get_mapping_str__(m, "") - s += "--------------------------------------------------------\n" + def print_tree(self, idx): + tokens = self.data[idx] + tree = self.get_tree(idx) + s = str(tokens) + "\n" + s += MeaningMap.get_tree_str(tree, "") return s def split(self, ratio): @@ -215,14 +294,38 @@ class MeaningDataset(Dataset): middle = int(l * ratio) d_shuffle = self.data.copy() l_shuffle = self.length.copy() - m_shuffle = self.mapping.copy() - md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle], mapping=m_shuffle[:middle]) - md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:], mapping=m_shuffle[middle:]) + m_shuffle = self.tree.copy() + level_shuffle = self.level.copy() + i_shuffle = self.idx.copy() + md1 = MeaningDataset( + data=d_shuffle[:middle], + length=l_shuffle[:middle], + tree=m_shuffle[:middle], + level=level_shuffle[:middle], + idx=i_shuffle[:middle], + ) + md2 = MeaningDataset( + data=d_shuffle[middle:], + length=l_shuffle[middle:], + tree=m_shuffle[middle:], + level=level_shuffle[middle:], + idx=i_shuffle[middle:], + ) return md1, md2 + def token_frequency(self): + freq = {} + for t in self.tree: + MeaningMap.token_frequency(t, freq) + return freq + + def get_seq_mask(idx, level, index): + assert index < 15, "index must < 15" + assert level < 8, "level must < 8" + return [((int(i / (16**level)) & 0xF) == index) for i in idx] + class BatchGroupMeaningDataloader(Dataset): - def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True): self.dataset = dataset self.batch_size = batch_size @@ -266,17 +369,28 @@ class BatchGroupMeaningDataloader(Dataset): def __getitem__(self, idx): return self.dataset.get_batch(self.indexBatch[idx]) - def mapping(self, idx): - return self.dataset.get_mapping_batch(self.indexBatch[idx]) + def get_tree(self, idx): + return [self.dataset.get_tree(i) for i in self.indexBatch[idx]] - def print_mapping(self, idx): - return self.dataset.print_mapping_batch(self.indexBatch[idx]) + def print_tree(self, idx): + idx_list = self.indexBatch[idx] + s = "--------------------------------------------------------\n" + for i in idx_list: + s += self.dataset.print_tree(i) + s += "--------------------------------------------------------\n" + return s if __name__ == "__main__": - md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024) + md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) train, val = md.split(0.95) + fdaf = md.__getitem__(920) + print(md.print_tree(920)) + print(md.idx[920]) + fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1) + print(fdasfe) + freq = md.token_frequency() dl = BatchGroupMeaningDataloader(train, 2) length = len(dl) @@ -285,9 +399,9 @@ if __name__ == "__main__": ne2 = next(it) ne3 = next(it) - map1 = dl.mapping(0) - map2 = dl.mapping(1) - print(dl.print_mapping(0)) + map1 = dl.get_tree(0) + map2 = dl.get_tree(1) + print(dl.print_tree(0)) dl = DataLoader( train, diff --git a/wit/train.py b/wit/train.py index be23a24..1260e92 100644 --- a/wit/train.py +++ b/wit/train.py @@ -17,7 +17,7 @@ pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" learning_rate = 0.0001 use_tril_attention_mask = None precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" -train_batch_size = 4 +train_batch_size = 2 val_batch_size = 1 num_proc = 8 max_epochs = 1000 @@ -25,14 +25,14 @@ strategy = "auto" resume_from_ckpt_path = None seed = 42 -vocab_size = 1024 -level_ratio = 4 -level = 6 +vocab_size = 256 +level_ratio = 6 +level = 4 dataset_level = 1 -hidden_size = 2048 # 128 1024 2048 32 +hidden_size = 1024 # 128 1024 2048 32 num_attention_heads = 16 # 8 8 16 -num_hidden_layers = 12 # 6 12 24 3 +num_hidden_layers = 3 # 6 12 24 3 name = "vocab_ratio_level_data_hidden_head_layer" ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"