diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py index c5e043d..a61efe2 100644 --- a/wit/meaning/meaning_dataset.py +++ b/wit/meaning/meaning_dataset.py @@ -15,6 +15,7 @@ else: class MeaningMap: + def __init__( self, size=1048576, @@ -30,7 +31,16 @@ class MeaningMap: assert min_subitem <= max_subitem, "Invalid input" np.random.seed(seed) self.size = size - self.vocab_size = vocab_size + + self.special_vocab_size = 0 + if stride > 1: + self.special_vocab_size = self.special_vocab_size + 1 + vocab_of_stride = vocab_size - self.special_vocab_size + if with_tree: + self.special_vocab_size = self.special_vocab_size + 1 + vocab_of_tree = vocab_size - self.special_vocab_size + self.normal_vocab_size = vocab_size - self.special_vocab_size + self.max_subitem = max_subitem self.min_subitem = min_subitem self.with_tree = with_tree @@ -89,8 +99,8 @@ class MeaningMap: map[mask_zero] = -1 - map[:vocab_size, 0] = np.arange(0, vocab_size) - map[:vocab_size, 1:] = -1 + map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size) + map[: self.normal_vocab_size, 1:] = -1 ms_level = [] # meaning level, vocab's level is 0 ms_rank_idx = [] # meaning index of all level @@ -106,13 +116,14 @@ class MeaningMap: ms_rank_all = np.zeros((datastep), dtype=np.uint32) # meaning all of all level index = 0 - for i in range(self.vocab_size): + + for i in range(self.normal_vocab_size): ms_data[index] = i ms_level[index] = 0 ms_rank_idx[index] = 0xFFFFFFF ms_rank_all[index] = 0xFFFFFFF for ind in range(index + 1, index + stride): - ms_data[ind] = -1 + ms_data[ind] = vocab_of_stride ms_level[ind] = 511 ms_rank_idx[ind] = 0xFFFFFFF ms_rank_all[ind] = 0xFFFFFFF @@ -123,7 +134,7 @@ class MeaningMap: ms_weight[i] = 1 index = index + stride - for i in range(self.vocab_size, size): + for i in range(self.normal_vocab_size, size): m = map[i] # 当前meaning的拆分的分支 m = m[m >= 0] # donot cut off the map such as [0] m_len = len(m) # 当前meaning的拆分的分支个数 @@ -132,7 +143,7 @@ class MeaningMap: m_list = m.tolist() assert m_list, "map list can not be empty list" - # 获取每个子meaning的start和end,并且生成序列组合成当前meaning完整的叶index(= 0].tolist(): - if m >= self.vocab_size: + if m >= nvs: pn = NodeTree(str(m), parent) - get_tree_node(ms_map, m, vocab_size, pn, seqlist) + get_tree_node(ms_map, m, nvs, pn, seqlist) else: pn = NodeTree("<" + str(m) + ">", parent) seqlist.append(pn) root = NodeTree(str(meaning)) seqlist = [] - get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist) + get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist) root.seq_node = seqlist return root @@ -246,7 +257,7 @@ class MeaningMap: def level_change(ms_map, meaning, current_to_common, common_to_current): ms = ms_map[meaning] for m in ms[ms >= 0].tolist(): - if m >= self.vocab_size: + if m >= self.normal_vocab_size: common_to_current[-1] = common_to_current[-1] + 1 level_change(ms_map, m, current_to_common, common_to_current) else: @@ -526,8 +537,8 @@ if __name__ == "__main__": tracemalloc.start() md = MeaningDataset( + 10000, 100000, - 300000, min_subitem=2, max_subitem=6, vocab_size=32, @@ -541,7 +552,16 @@ if __name__ == "__main__": print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB") tracemalloc.stop() - md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False) + md = MeaningDataset( + 10000, + 100000, + vocab_size=32, + stride=2, + min_subitem=2, + max_subitem=6, + with_tree=False, + use_cache=False, + ) item = md.__getitem__(920) mm = md.get_meaning_map() mm.get_nodetree(int(item["meaning"][0])).print() @@ -550,7 +570,7 @@ if __name__ == "__main__": print(item_seq) print(item_mask) - md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False) + md = MeaningDataset(10000, 100000, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False) train, val = md.split(0.95) item = md.__getitem__(920)