diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 61f9d35..44940c5 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -6,14 +6,13 @@ from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
-from dataset.node_tree import NodeTree
-
-# import warnings
-# warnings.filterwarnings("ignore", ".*does not have many workers.*")
+from node_tree import NodeTree
 
 
 class MeaningMap:
-    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
+    def __init__(
+        self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, with_parent=False, use_cache=True, seed=42
+    ):
         assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
@@ -21,6 +20,7 @@ class MeaningMap:
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
+        self.with_parent = with_parent
 
         datastep = 0x8000000
         path = "./data/"
@@ -108,6 +108,8 @@ class MeaningMap:
                 m = map[i]  # branches of the current meaning's expansion
                 m = m[m >= 0]  # do not cut off the map, e.g. [0]
                 m_len = len(m)  # number of branches in the current meaning's expansion
+                if self.with_parent:
+                    m_len = m_len + 1
                 m_list = m.tolist()
                 assert m_list, "map list can not be empty list"
 
@@ -117,6 +119,8 @@ class MeaningMap:
                     [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
                 )
                 len_ma = len(idx)
+                if self.with_parent:
+                    len_ma = len_ma + 1
                 end = index + len_ma
 
                 if ms_data.size < end:  # storage exceeded: grow each buffer by one datastep
@@ -125,10 +129,26 @@ class MeaningMap:
                     ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
                     ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
 
-                ms_data[index:end] = ms_data[idx]  # append all tokens of the current meaning to the data buffer
-                ms_level[index:end] = ms_level[idx] + 1  # handle level
-                ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)  # handle rank_idx
-                ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)  # handle rank_all
+                # append all tokens of the current meaning to the data buffer
+                new_data = ms_data[idx]
+                if self.with_parent:
+                    new_data = np.concatenate([new_data, np.array([i])])
+                ms_data[index:end] = new_data
+                # handle level
+                new_level = ms_level[idx] + 1
+                if self.with_parent:
+                    new_level = np.concatenate([new_level, np.array([256])])
+                ms_level[index:end] = new_level
+                # handle rank_idx
+                new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
+                if self.with_parent:
+                    new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
+                ms_rank_idx[index:end] = new_idx
+                # handle rank_all
+                new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
+                if self.with_parent:
+                    new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
+                ms_rank_all[index:end] = new_all
 
                 ms_start[i] = index
                 ms_end[i] = end
@@ -275,6 +295,7 @@ class MeaningDataset(Dataset):
         min_subitem=1,
         min_seq_len=2,
         seed=42,
+        with_parent=False,
         use_cache=True,
     ):
         np.random.seed(seed)
@@ -283,6 +304,7 @@ class MeaningDataset(Dataset):
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
+        self.with_parent = with_parent
         self.use_cache = use_cache
         self.min_seq_len = min_seq_len
         print("Build MeaningDataset from MeaningMap.")
@@ -310,9 +332,11 @@ class MeaningDataset(Dataset):
             self.seq_meaning.append(m)
             seq_len.append(len(d))
 
+            origin_l = l.copy()
+            origin_l[l >= 255] = l[l >= 255] - 255
             dm = np.ones(i.shape, dtype=np.uint32)
-            dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
-            shift = (8 - l) * 4
+            dm = ((dm * 0xFFFFFFFF) << (origin_l * 4)).astype(np.uint32)
+            shift = (8 - origin_l) * 4
             rank_idx = (i & 0xF) << 28
             rank_idx = rank_idx + ((i & 0xF0) << 20)
             rank_idx = rank_idx + ((i & 0xF00) << 12)
@@ -351,14 +375,16 @@ class MeaningDataset(Dataset):
         return len(self.seq)
 
     def get_meaning_map(self):
-        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
+        return MeaningMap(
+            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.with_parent, self.use_cache
+        )
 
     def set_mask(self, level=None, idx=None):
-        if self.val_mask_level is not None and self.val_mask_idx is not None:
-            assert len(self.val_mask_level) > 0, "len must > 0"
-            assert len(self.val_mask_level) == len(self.val_mask_idx), "mask level and mask index must be same length"
-            assert isinstance(self.val_mask_level, list), "mask level must be list"
-            assert isinstance(self.val_mask_idx, list), "mask index must be list"
+        if level is not None and idx is not None:
+            assert len(level) > 0, "len must > 0"
+            assert len(level) == len(idx), "mask level and mask index must be same length"
+            assert isinstance(level, list), "mask level must be list"
+            assert isinstance(idx, list), "mask index must be list"
         self.val_mask_level = level
         self.val_mask_idx = idx
 
@@ -403,7 +429,7 @@ class MeaningDataset(Dataset):
         # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
-        return rank_idx == (rank_all + index if index < 0 else index)
+        return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)
 
     def get_seq_mask_tensor(self, idx_list):
         if self.val_mask_level is not None and self.val_mask_idx is not None:
@@ -418,7 +444,7 @@ class MeaningDataset(Dataset):
                 )
             return mask
         else:
-            return None
+            return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
 
 
 class BatchGroupMeaningDataloader(Dataset):
@@ -476,11 +502,19 @@ class BatchGroupMeaningDataloader(Dataset):
 
 
 if __name__ == "__main__":
-    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
-    md.set_mask([1], [-1])
+    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, with_parent=True, use_cache=False)
     train, val = md.split(0.95)
-    fdaf = md.__getitem__(920)
-    print(md.print_tree(920))
+    item = md.__getitem__(920)
+
+    mm = md.get_meaning_map()
+    mm.get_nodetree(int(item["meaning"][0])).print()
+    item_seq = item["input_ids"].numpy().tolist()
+    item_mask = item["val_mask"].numpy().tolist()
+    print(item_seq)
+    print(item_mask)
+
+    md.set_mask([1], [-1])
+
     print(md.rank_idx[920])
     print(md.rank_all[920])
     mask = md.get_seq_mask(920, 0, -1)
@@ -491,72 +525,6 @@ if __name__ == "__main__":
     print(mask)
     mask = md.get_seq_mask(920, 1, 1)
     print(mask)
-    assert all(
-        np.equal(
-            mask[0:57],
-            np.array(
-                [
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    True,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    True,
-                    True,
-                    True,
-                    True,
-                    False,
-                    False,
-                    False,
-                    False,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                    False,
-                ]
-            ),
-        )
-    ), "False"
 
     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
     md.set_mask([0, 1], [0, -1])
diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md
index 47e6294..6513323 100644
--- a/wit/doc/meaning_dataset.md
+++ b/wit/doc/meaning_dataset.md
@@ -18,6 +18,10 @@ The meaning dataset imitates natural language as well as abstract expression.
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level, val_mask_idx: the mask over tokens used for training; val_mask_level=[0, 1, 2] with val_mask_idx=[0, 0, -1] means a token takes part in training only if it is the 0th at level 0, the 0th at level 1, and the last at level 2
+15. with_parent: every meaning's expansion ends with the current meaning's own number, so branch nodes are inserted into the sequence rather than only leaf nodes
+16. ms_level: a value greater than 255 marks the token as a parent (level 0 is stored as 256); such tokens are False when the val mask is generated
+17. ms_map: the sub-meanings each meaning is decomposed into
+18. index must be < 15, level must be < 8
 
 ## code
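
Note on doc items 15-18: the "index must be < 15, level must be < 8" constraint follows from the nibble packing that `get_map` applies to `ms_rank_idx` / `ms_rank_all` — each enclosing meaning multiplies the packed value by 16 and adds a 4-bit child index or branch count, so a uint32 holds at most 8 levels. Below is a minimal sketch of just that packing, outside the patch; `pack_rank` is a hypothetical helper, not code from the repository:

```python
import numpy as np


def pack_rank(path):
    """Repeat the patch's `ms_rank_idx[idx] * 16 + idxidx` update once per
    tree level. `path` lists (child_index, branch_count) pairs from the
    deepest split up to the top split, mirroring the bottom-up construction
    order in get_map; the top level is packed last and therefore lands in
    the lowest nibble.
    """
    assert len(path) <= 8, "a uint32 holds at most 8 nibbles: level must < 8"
    rank_idx, rank_all = np.uint32(0), np.uint32(0)
    for child_index, branch_count in path:
        assert child_index < 15, "index must < 15"  # one nibble per level
        rank_idx = np.uint32(rank_idx * 16 + child_index)
        rank_all = np.uint32(rank_all * 16 + branch_count)
    return rank_idx, rank_all


# A token that is branch 2 of 4 in its deepest split and branch 0 of 5
# at the top split:
rank_idx, rank_all = pack_rank([(2, 4), (0, 5)])
assert rank_idx & 0xF == 0 and rank_all & 0xF == 5                 # top split
assert (rank_idx >> 4) & 0xF == 2 and (rank_all >> 4) & 0xF == 4   # deeper split
```

With `with_parent=True` the inserted parent token bypasses this packing: the patch writes `0xFFFFFFF0 + m_len - 1` / `0xFFFFFFF0 + m_len` into its rank slots and a level of 256 (incremented once per enclosing meaning), which is why `get_seq_mask` and the new `get_seq_mask_tensor` fallback both filter on `self.level[idx] < 255`.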