From 9c75b8920d1530b91b50dbed33a354b920cd5b46 Mon Sep 17 00:00:00 2001
From: Colin <>
Date: Thu, 7 Aug 2025 17:36:43 +0800
Subject: [PATCH] Add stride support in meaning_dataset.

---
 wit/dataset/meaning_dataset.py | 65 +++++++++++++++++++++++-----------
 wit/doc/meaning_dataset.md     |  4 +--
 2 files changed, 47 insertions(+), 22 deletions(-)

diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 44940c5..f637b62 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -11,7 +11,15 @@ from node_tree import NodeTree
 
 class MeaningMap:
     def __init__(
-        self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, with_parent=False, use_cache=True, seed=42
+        self,
+        size=1048576,
+        vocab_size=4096,
+        max_subitem=10,
+        min_subitem=1,
+        stride=1,
+        with_tree=False,
+        use_cache=True,
+        seed=42,
     ):
         assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
         assert min_subitem <= max_subitem, "Invalid input"
@@ -20,7 +28,8 @@ class MeaningMap:
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
-        self.with_parent = with_parent
+        self.with_tree = with_tree
+        self.stride = stride
 
         datastep = 0x8000000
         path = "./data/"
@@ -93,22 +102,27 @@ class MeaningMap:
 
         index = 0
         for i in range(self.vocab_size):
-            ms_data[i] = i
+            ms_data[index] = i
+            ms_level[index] = 0
+            ms_rank_idx[index] = 0xFFFFFFF
+            ms_rank_all[index] = 0xFFFFFFF
+            for ind in range(index + 1, index + stride):
+                ms_data[ind] = -1
+                ms_level[ind] = 511
+                ms_rank_idx[ind] = 0xFFFFFFF
+                ms_rank_all[ind] = 0xFFFFFFF
             ms_start[i] = index
-            ms_end[i] = index + 1
-            ms_len[i] = 1
-            ms_level[i] = 0
-            ms_rank_idx[i] = 0xFFFFFFF
-            ms_rank_all[i] = 0xFFFFFFF
+            ms_end[i] = index + stride
+            ms_len[i] = stride
             ms_height[i] = 0
             ms_weight[i] = 1
-            index = index + 1
+            index = index + stride
 
         for i in range(self.vocab_size, size):
             m = map[i]  # branches of the current meaning's expansion
             m = m[m >= 0]  # do not cut off the map such as [0]
             m_len = len(m)  # number of branches in the current meaning's expansion
-            if self.with_parent:
+            if self.with_tree:
                 m_len = m_len + 1
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"
@@ -119,7 +133,7 @@ class MeaningMap:
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
             )
             len_ma = len(idx)
-            if self.with_parent:
+            if self.with_tree:
                 len_ma = len_ma + 1
 
             end = index + len_ma
@@ -131,22 +145,22 @@ class MeaningMap:
 
             # concatenate all tokens of the current meaning into the data structure
             new_data = ms_data[idx]
-            if self.with_parent:
+            if self.with_tree:
                 new_data = np.concatenate([new_data, np.array([i])])
             ms_data[index:end] = new_data
             # handle level
             new_level = ms_level[idx] + 1
-            if self.with_parent:
-                new_level = np.concatenate([new_level, np.array([256])])
+            if self.with_tree:
+                new_level = np.concatenate([new_level, np.array([255 + 1])])
             ms_level[index:end] = new_level
             # handle rank_idx
             new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
-            if self.with_parent:
+            if self.with_tree:
                 new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
             ms_rank_idx[index:end] = new_idx
             # handle rank_all
             new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
-            if self.with_parent:
+            if self.with_tree:
                 new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
             ms_rank_all[index:end] = new_all
 
@@ -295,7 +309,8 @@ class MeaningDataset(Dataset):
         min_subitem=1,
         min_seq_len=2,
         seed=42,
-        with_parent=False,
+        stride=1,
+        with_tree=False,
         use_cache=True,
     ):
         np.random.seed(seed)
@@ -304,7 +319,8 @@ class MeaningDataset(Dataset):
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
-        self.with_parent = with_parent
+        self.stride = stride
+        self.with_tree = with_tree
         self.use_cache = use_cache
         self.min_seq_len = min_seq_len
         print("Build MeaningDataset from MeaningMap.")
@@ -376,7 +392,7 @@ class MeaningDataset(Dataset):
 
     def get_meaning_map(self):
         return MeaningMap(
-            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.with_parent, self.use_cache
+            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
         )
 
     def set_mask(self, level=None, idx=None):
@@ -502,7 +518,16 @@ class BatchGroupMeaningDataloader(Dataset):
 
 
 if __name__ == "__main__":
-    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, with_parent=True, use_cache=False)
+    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
+    item = md.__getitem__(920)
+    mm = md.get_meaning_map()
+    mm.get_nodetree(int(item["meaning"][0])).print()
+    item_seq = item["input_ids"].numpy().tolist()
+    item_mask = item["val_mask"].numpy().tolist()
+    print(item_seq)
+    print(item_mask)
+
+    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
 
     train, val = md.split(0.95)
     item = md.__getitem__(920)
diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md
index 6513323..6d1fdf8 100644
--- a/wit/doc/meaning_dataset.md
+++ b/wit/doc/meaning_dataset.md
@@ -18,8 +18,8 @@ The meaning dataset imitates natural language as well as abstract expression.
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level val_mask_idx: the mask over tokens used for training; val_mask_level=[0, 1, 2] with val_mask_idx=[0, 0, -1] means only a token that is the 0th at level 0, the 0th at level 1, and the last at level 2 takes part in training
-15. with_parent means every meaning's expansion ends with the current meaning's id, so the sequence contains many branch nodes, not only leaf nodes
-16. ms_level greater than 255 means the token is a parent token (256 at level 0); it is False when generating the val mask
+15. with_tree means every meaning's expansion ends with the current meaning's id, so the sequence contains many branch nodes, not only leaf nodes
+16. ms_level greater than 255 means the token is a tree (parent) token (256 at level 0); 511 or greater means the token is stride padding (511 at level 0); both are False when generating the val mask
+17. ms_map holds the sub meanings that each meaning expands into
 18. index must be < 15, level must be < 8
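
Note (not part of the patch): the sketch below is a minimal, standalone illustration of the slot layout that the new stride parameter produces for vocab-level meanings in MeaningMap.__init__. Array names mirror the patch, but the dtypes and sizes are chosen only for this example; it is not the repository code.

import numpy as np

vocab_size = 4
stride = 2

# Each vocab meaning occupies `stride` consecutive slots.
ms_data = np.zeros(vocab_size * stride, dtype=np.int32)
ms_level = np.zeros(vocab_size * stride, dtype=np.int32)
ms_start = np.zeros(vocab_size, dtype=np.int32)
ms_end = np.zeros(vocab_size, dtype=np.int32)

index = 0
for i in range(vocab_size):
    ms_data[index] = i        # real token id, level 0
    ms_level[index] = 0
    for ind in range(index + 1, index + stride):
        ms_data[ind] = -1     # stride padding slot
        ms_level[ind] = 511   # padding marker; excluded from the val mask
    ms_start[i] = index
    ms_end[i] = index + stride
    index += stride

print(ms_data.tolist())   # [0, -1, 1, -1, 2, -1, 3, -1]
print(ms_level.tolist())  # [0, 511, 0, 511, 0, 511, 0, 511]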