diff --git a/wit/meaning_dataset.md b/wit/meaning_dataset.md index 9b54789..66b5ca4 100644 --- a/wit/meaning_dataset.md +++ b/wit/meaning_dataset.md @@ -7,14 +7,16 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。 1. token表示最终体现的基本数据表达,类似单词。vocab_size表示代表token的数量。 2. meaning表示一种语义(符号),所有的meaning都由一个编号表达,编号越大表示语义越复杂 3. 所有的meaning都可以由更低标号表达 -4. 从0到vocab_size的编号表示基本meaning,是不能被拆解的,也就是token +4. 从0到(vocab_size-1)的编号表示基本meaning,是不能被拆解的,也就是token 5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据 6. level表示当前token相对于root meaning的距离 -7. idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的index,高位无用的位用1填充 +7. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充 +7. rank_all表示当前token在不同层的分子个数,每4位表示在一层里面的编号,低4位表示最低层级的rank_all,高位无用的位用1填充 8. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构 -9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index -10. meaning_height -11. meaning_weight +9. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个 +10. meaning_height 当前meaning的总高度 +11. meaning_weight 当前meaning的总宽度 + ``` vocab_size = 256 meaning = 115200 diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index d8d18d7..416134e 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -8,6 +8,7 @@ from typing import Dict, Tuple from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split import numpy as np from torch.utils.data import BatchSampler +import copy class MeaningMap: @@ -18,7 +19,7 @@ class MeaningMap: path = "./data/" file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) - file = path + file + file = path + file + ".npz" if not os.path.exists(path): os.mkdir(path) @@ -26,13 +27,14 @@ class MeaningMap: print("Load from disk cache: " + file) loaded = np.load(file) slhwm = loaded["slhwm"] - dli = loaded["dli"] + dlra = loaded["dlra"] self.ms_map = slhwm[:, 4:] - self.ms_data = dli[:, 0] + self.ms_data = dlra[:, 0] self.ms_start = slhwm[:, 0] self.ms_len = slhwm[:, 1] - self.ms_level = dli[:, 1] - self.ms_idx = dli[:, 2].astype(np.uint32) + self.ms_level = dlra[:, 1] + self.ms_rank_idx = dlra[:, 2].astype(np.uint32) + self.ms_rank_all = dlra[:, 3].astype(np.uint32) self.ms_height = slhwm[:, 2] self.ms_weight = slhwm[:, 3] print("Load end") @@ -61,7 +63,8 @@ class MeaningMap: ms_data = [] # meaning sequence ms_level = [] # meaning level, vocab's level is 0 - ms_idx = [] # meaning index of lowest level + ms_rank_idx = [] # meaning index of all level + ms_rank_all = [] # meaning all of all level ms_start = [] # meaning sequence start ms_len = [] # meaning sequence length ms_height = [] # meaning tree height @@ -70,7 +73,8 @@ class MeaningMap: for i in range(self.vocab_size): ms_data.append(np.array([i])) ms_level.append(np.array([0])) - ms_idx.append(np.array([0])) + ms_rank_idx.append(np.array([0])) + ms_rank_all.append(np.array([0])) ms_start.append(index) ms_len.append(1) ms_height.append(0) @@ -79,59 +83,70 @@ class MeaningMap: for i in range(self.vocab_size, size): m = map[i] - m = m[m >= 0] + m = m[m >= 0] # donot cut off the map such as [0] m_list = m.tolist() + m_len = len(m_list) assert m_list, "map list can not be empty list" ma = np.concatenate([ms_data[newm] for newm in m_list]) ml = np.concatenate([ms_level[newm] + 1 for newm in m_list]) - mi = np.concatenate( + mr = np.concatenate( [ - ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i) + ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i) for i, newm in enumerate(m_list) ] ) - ml = ml[ma > 0] - mi = mi[ma > 0] - ma = ma[ma > 0] + mrl = np.concatenate( + [ + ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len) + for i, newm in enumerate(m_list) + ] + ) + # ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32] + # mr = mr[ma > 0] + # mrl = mrl[ma > 0] + # ma = ma[ma > 0] ms_data.append(ma) ms_level.append(ml) - ms_idx.append(mi) + ms_rank_idx.append(mr) + ms_rank_all.append(mrl) ms_start.append(index) ms_len.append(len(ma)) ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1) ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list)) index = index + len(ma) - # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28] - # for idxmi, mi in enumerate(ms_idx): - # level = ms_level[idxmi] - # for idxnum, num in enumerate(mi): - # l = level[idxnum] - # elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]] - # num = (num >> (l * 4)) << (l * 4) - # num += sum(elem << (i * 4) for i, elem in enumerate(elements)) - # mi[idxnum] = num - ms_data = np.array(list(chain(*ms_data))).astype(np.int32) ms_level = np.array(list(chain(*ms_level))).astype(np.int32) - ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32) + ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32) + ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32) - d = np.ones(ms_idx.shape, dtype=np.uint32) + d = np.ones(ms_rank_idx.shape, dtype=np.uint32) d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32) - ms_idx = ( - ((ms_idx & 0xF) << 28) - + ((ms_idx & 0xF0) << 20) - + ((ms_idx & 0xF00) << 12) - + ((ms_idx & 0xF000) << 4) - + ((ms_idx & 0xF0000) >> 4) - + ((ms_idx & 0xF00000) >> 12) - + ((ms_idx & 0xF000000) >> 20) - + ((ms_idx & 0xF0000000) >> 28) + ms_rank_idx = ( + ((ms_rank_idx & 0xF) << 28) + + ((ms_rank_idx & 0xF0) << 20) + + ((ms_rank_idx & 0xF00) << 12) + + ((ms_rank_idx & 0xF000) << 4) + + ((ms_rank_idx & 0xF0000) >> 4) + + ((ms_rank_idx & 0xF00000) >> 12) + + ((ms_rank_idx & 0xF000000) >> 20) + + ((ms_rank_idx & 0xF0000000) >> 28) ) - ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) + ms_rank_idx = ((ms_rank_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) + ms_rank_all = ( + ((ms_rank_all & 0xF) << 28) + + ((ms_rank_all & 0xF0) << 20) + + ((ms_rank_all & 0xF00) << 12) + + ((ms_rank_all & 0xF000) << 4) + + ((ms_rank_all & 0xF0000) >> 4) + + ((ms_rank_all & 0xF00000) >> 12) + + ((ms_rank_all & 0xF000000) >> 20) + + ((ms_rank_all & 0xF0000000) >> 28) + ) + ms_rank_all = ((ms_rank_all >> ((8 - ms_level) * 4)) + d).astype(np.uint32) ms_start = np.array(ms_start).astype(np.int32) ms_height = np.array(ms_height).astype(np.int32) @@ -148,15 +163,17 @@ class MeaningMap: ), axis=1, ) - dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1) - np.savez(file, slhwm=slhwm, dli=dli) + dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1) + np.savez(file, slhwm=slhwm, dlra=dlra) + + self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]] + self.ms_level = ms_level + self.ms_rank_idx = ms_rank_idx + self.ms_rank_all = ms_rank_all self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)] - self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]] self.ms_start = ms_start self.ms_len = ms_len - self.ms_level = ms_level - self.ms_idx = ms_idx self.ms_height = ms_height self.ms_weight = ms_weight print("Disk cache build end.") @@ -164,7 +181,12 @@ class MeaningMap: def get_sequence(self, meaning): # return sequence[meaning] start = self.ms_start[meaning] len = self.ms_len[meaning] - return self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_idx[start : start + len] + return ( + self.ms_data[start : start + len], + self.ms_level[start : start + len], + self.ms_rank_idx[start : start + len], + self.ms_rank_all[start : start + len], + ) def get_tree(self, meaning): # return meaning all sub items tree = {} @@ -203,73 +225,70 @@ class MeaningMap: class MeaningDataset(Dataset): + def __init__( self, - start=131072, - end=1048576, - size=32768, - vocab_size=4096, + start, + end, + size, + vocab_size, max_subitem=10, min_seq_len=2, seed=42, - data=None, - length=None, - tree=None, - level=None, - idx=None, use_cache=True, ): - if data != None and length != None and tree != None and level != None and idx != None: - self.data = data - self.length = length - self.tree = tree - self.level = level - self.idx = idx - return np.random.seed(seed) map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache) + np.random.seed(seed) + self.tree = [] - self.data = [] + self.seq = [] self.level = [] - self.idx = [] - self.length = [] + self.rank_idx = [] + self.rank_all = [] + self.seq_meaning = [] + self.m_height = map.ms_height + self.m_weight = map.ms_weight meanings = np.random.randint(start, end, size=(size)) + + seq_len = [] for m in meanings: - d, l, i = map.get_sequence(m) + d, l, i, a = map.get_sequence(m) if len(d) >= min_seq_len: self.tree.append({m: map.get_tree(m)}) - self.data.append(d) + self.seq.append(d) self.level.append(l) - self.idx.append(i) - self.length.append(len(d)) + self.rank_idx.append(i) + self.rank_all.append(a) + self.seq_meaning.append(m) + seq_len.append(len(d)) - unique, counts = np.unique(self.length, return_counts=True) + unique, counts = np.unique(seq_len, return_counts=True) print("----------------------------------------------------------------") print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start)) - print("MeaningDataset size:" + str(len(self.length))) + print("MeaningDataset size:" + str(len(seq_len))) print("MeaningDataset max sequence length:" + str(max(unique))) print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)])) print("----------------------------------------------------------------") def __len__(self): - return len(self.data) + return len(self.seq) def len(self): - return len(self.data) + return len(self.seq) def __getitem__(self, idx): output = {} - data = torch.tensor(self.data[idx]).long() + data = torch.tensor(self.seq[idx]).long() output["input_ids"] = data output["labels"] = data.clone() output["token_type_ids"] = torch.zeros(data.shape) output["tree"] = self.tree[idx] output["level"] = self.level[idx] - output["idx"] = self.idx[idx] return output def get_batch(self, idx_list): # must equal sequence length - data = [self.data[i] for i in idx_list] + data = [self.seq[i] for i in idx_list] output = {} data = torch.tensor(np.stack(data, axis=0)).long() output["input_ids"] = data @@ -277,45 +296,35 @@ class MeaningDataset(Dataset): output["token_type_ids"] = torch.zeros(data.shape) output["tree"] = [self.tree[i] for i in idx_list] output["level"] = [self.level[i] for i in idx_list] - output["idx"] = [self.idx[i] for i in idx_list] return output def get_token(self, idx): # must equal sequence length - return self.data[idx] + return self.seq[idx] def get_tree(self, idx): return self.tree[idx] def print_tree(self, idx): - tokens = self.data[idx] + tokens = self.seq[idx] tree = self.get_tree(idx) s = str(tokens) + "\n" s += MeaningMap.get_tree_str(tree, "") return s + def copy(self, start, end): + new = copy.deepcopy(self) + new.tree = new.tree[start:end] + new.seq = new.seq[start:end] + new.level = new.level[start:end] + new.rank_idx = new.rank_idx[start:end] + new.rank_all = new.rank_all[start:end] + new.seq_meaning = new.seq_meaning[start:end] + return new + def split(self, ratio): - l = len(self.data) + l = self.len() middle = int(l * ratio) - d_shuffle = self.data.copy() - l_shuffle = self.length.copy() - m_shuffle = self.tree.copy() - level_shuffle = self.level.copy() - i_shuffle = self.idx.copy() - md1 = MeaningDataset( - data=d_shuffle[:middle], - length=l_shuffle[:middle], - tree=m_shuffle[:middle], - level=level_shuffle[:middle], - idx=i_shuffle[:middle], - ) - md2 = MeaningDataset( - data=d_shuffle[middle:], - length=l_shuffle[middle:], - tree=m_shuffle[middle:], - level=level_shuffle[middle:], - idx=i_shuffle[middle:], - ) - return md1, md2 + return self.copy(0, middle), self.copy(middle, l) def token_frequency(self): freq = {} @@ -323,10 +332,12 @@ class MeaningDataset(Dataset): MeaningMap.token_frequency(t, freq) return freq - def get_seq_mask(idx, level, index): + def get_seq_mask(self, idx, level, index): assert index < 15, "index must < 15" assert level < 8, "level must < 8" - return [((int(i / (16**level)) & 0xF) == index) for i in idx] + rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF + rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF + return rank_idx == (rank_all + index if index < 0 else index) class BatchGroupMeaningDataloader(Dataset): @@ -335,11 +346,11 @@ class BatchGroupMeaningDataloader(Dataset): self.batch_size = batch_size self.drop_last = drop_last - length = dataset.length - unique, counts = np.unique(length, return_counts=True) + seq_len = [len(s) for s in dataset.seq] + unique, counts = np.unique(seq_len, return_counts=True) gl = {} for u in unique: - gl[u] = np.where(length == u)[0] + gl[u] = np.where(seq_len == u)[0] lens = list(gl.keys()) gs = {} @@ -365,7 +376,7 @@ class BatchGroupMeaningDataloader(Dataset): index = index[index_shuffle] self.indexBatch = index print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index))) - print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - len(index) * batch_size)) + print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size)) def __len__(self): return len(self.indexBatch) @@ -387,229 +398,109 @@ class BatchGroupMeaningDataloader(Dataset): if __name__ == "__main__": - md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True) + md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) train, val = md.split(0.95) fdaf = md.__getitem__(920) print(md.print_tree(920)) - print(md.idx[920]) - mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1) + print(md.rank_idx[920]) + print(md.rank_all[920]) + mask = md.get_seq_mask(920, 0, -1) print(mask) - assert mask == [ - False, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - False, - False, - False, - False, - False, - False, - False, - False, - True, - True, - True, - True, - True, - True, - True, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - True, - True, - False, - False, - False, - False, - False, - False, - True, - True, - True, - True, - True, - True, - True, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - True, - True, - True, - True, - True, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - True, - True, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - True, - True, - True, - True, - True, - True, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - False, - ], "False" + mask = md.get_seq_mask(920, 1, 0) + print(mask) + mask = md.get_seq_mask(920, 1, -1) + print(mask) + mask = md.get_seq_mask(920, 1, 1) + print(mask) + assert all( + np.equal( + mask[0:57], + np.array( + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + ] + ), + ) + ), "False" freq = md.token_frequency() dl = BatchGroupMeaningDataloader(train, 2) - length = len(dl) - it = iter(dl) - ne1 = next(it) - ne2 = next(it) - ne3 = next(it) + # length = len(dl) + # it = iter(dl) + # ne1 = next(it) + # ne2 = next(it) + # ne3 = next(it) - map1 = dl.get_tree(0) - map2 = dl.get_tree(1) - print(dl.print_tree(0)) + # map1 = dl.get_tree(0) + # map2 = dl.get_tree(1) + # print(dl.print_tree(0)) - dl = DataLoader( - train, - num_workers=1, - persistent_workers=True, - shuffle=False, - ) - it = iter(dl) - ne1 = next(it) - ne2 = next(it) - ne3 = next(it) + # dl = DataLoader( + # train, + # num_workers=1, + # persistent_workers=True, + # shuffle=False, + # ) + # it = iter(dl) + # ne1 = next(it) + # ne2 = next(it) + # ne3 = next(it) - for i in range(10): - print(next(it)["input_ids"].numpy().tolist()) + # for i in range(10): + # print(next(it)["input_ids"].numpy().tolist())