diff --git a/wit/configuration.py b/wit/configuration.py
index d8c419f..4fab7fd 100644
--- a/wit/configuration.py
+++ b/wit/configuration.py
@@ -40,6 +40,7 @@ class MeaningDatasetConfig:
     def __init__(self):
         self.start = 10000
         self.end = 200000
+        self.reserve_vocab = 0
         self.size = None
         self.min_subitem = 2
         self.max_subitem = 10
diff --git a/wit/inference.py b/wit/inference.py
index 80ef079..3449500 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -21,7 +21,8 @@ if __name__ == "__main__":

     runner = ModelRunner(qwen.llm)

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf)
+    val = val.dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
diff --git a/wit/meaning/__init__.py b/wit/meaning/__init__.py
index 91f4e26..fe50a54 100644
--- a/wit/meaning/__init__.py
+++ b/wit/meaning/__init__.py
@@ -1,2 +1 @@
 from .dataset import InitDataset
-from .dataset import InitValDataset
diff --git a/wit/meaning/dataset.py b/wit/meaning/dataset.py
index 55ce971..dcba06b 100644
--- a/wit/meaning/dataset.py
+++ b/wit/meaning/dataset.py
@@ -35,10 +35,11 @@ def InitDataset(config):
         size = c.size
         end = c.end
         seed = c.seed
+        reserve_vocab = c.reserve_vocab

         path = "./data/"
         conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
+        conf_name = conf_name + f"_vocab{vocab}_reserve_vocab{reserve_vocab}_stride{c.stride}_tree{c.with_tree}.pt"
         trainfile = path + f"MeaningDataset_train" + conf_name
         valfile = path + f"MeaningDataset_val" + conf_name
         if not os.path.exists(path):
@@ -56,6 +57,7 @@ def InitDataset(config):
             start,
             end,
             vocab,
+            reserve_vocab,
             size,
             c.max_subitem,
             c.min_subitem,
@@ -74,58 +76,3 @@ def InitDataset(config):
         )
         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
         return train_dataloader, val_dataloader
-
-
-def InitValDataset(config):
-
-    val_batch_size = config.val_batch_size
-    num_proc = config.num_proc
-    if config.dataset.name == "special":
-        raw_dataset = SpecialDataset()
-        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=val_batch_size,
-            num_workers=num_proc,
-            persistent_workers=True,
-        )
-        return val_dataloader
-
-    if config.dataset.name == "meaning":
-        c = config.dataset.meaning
-        vocab = config.model_config.vocab_size
-        start = c.start
-        size = c.size
-        end = c.end
-        seed = c.seed
-
-        path = "./data/"
-        conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
-        valfile = path + f"MeaningDataset_val" + conf_name
-        if not os.path.exists(path):
-            os.mkdir(path)
-        if os.path.exists(valfile):
-            print(f"INFO: Load dataset from {valfile}")
-            val_dataset = torch.load(valfile, weights_only=False)
-            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            print(f"INFO: Load dataset end")
-        else:
-            raw_dataset = MeaningDataset(
-                start,
-                end,
-                vocab,
-                size,
-                c.max_subitem,
-                c.min_subitem,
-                stride=c.stride,
-                with_tree=c.with_trees,
-                seed=seed,
-            )
-            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            train_dataset, val_dataset = raw_dataset.split(0.9)
-            torch.save(val_dataset, valfile)
-            print(f"INFO: Build and save dataset end")
-
-        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
-        return val_dataloader
diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py
index bc22f65..0baa894 100644
--- a/wit/meaning/meaning_dataset.py
+++ b/wit/meaning/meaning_dataset.py
@@ -20,6 +20,7 @@ class MeaningMap:
         self,
         size=1048576,
         vocab_size=4096,
+        reserve_vocab=0,
         max_subitem=10,
         min_subitem=1,
         stride=1,
@@ -31,15 +32,17 @@ class MeaningMap:
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)

         self.size = size
+        self.reserve_vocab = reserve_vocab

-        self.special_vocab_size = 0
+        self.special_vocab = 0
         if stride > 1:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_stride = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_stride = vocab_size - self.special_vocab
         if with_tree:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_tree = vocab_size - self.special_vocab_size
-        self.normal_vocab_size = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_tree = vocab_size - self.special_vocab
+        assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
+        self.normal_vocab = vocab_size - self.reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
@@ -49,7 +52,7 @@ class MeaningMap:

         path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size)
-        file += "_" + str(max_subitem) + "_" + str(min_subitem)
+        file += "_" + str(reserve_vocab) + "_" + str(max_subitem) + "_" + str(min_subitem)
         file += "_" + str(stride) + "_" + str(with_tree) + "_" + str(seed)
         file_prop = path + file + "_prop.npy"
         file_data = path + file + "_data.npy"
@@ -100,8 +103,8 @@ class MeaningMap:

             map[mask_zero] = -1

-            map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
-            map[: self.normal_vocab_size, 1:] = -1
+            map[: self.normal_vocab, 0] = np.arange(0, self.normal_vocab)
+            map[: self.normal_vocab, 1:] = -1

             ms_level = []  # meaning level, vocab's level is 0
             ms_rank_idx = []  # meaning index of all level
@@ -118,7 +121,7 @@ class MeaningMap:

             index = 0

-            for i in range(self.normal_vocab_size):
+            for i in range(self.normal_vocab):
                 ms_data[index] = i
                 ms_level[index] = 0
                 ms_rank_idx[index] = 0xFFFFFFF
@@ -135,14 +138,14 @@ class MeaningMap:
                 ms_weight[i] = 1
                 index = index + stride

-            for i in range(self.normal_vocab_size, size):
+            for i in range(self.normal_vocab, size):
                 m = map[i]  # branches the current meaning splits into
                 m = m[m >= 0]  # do not cut off the map, such as [0]
                 m_len = len(m)  # number of branches the current meaning splits into
                 m_list = m.tolist()
                 assert m_list, "map list can not be empty list"
-                # get each sub-meaning's start and end, and build the sequences that compose the current meaning's complete leaf index (…
@@ … @@
 …>= 0].tolist():
-            if m >= self.normal_vocab_size:
+            if m >= self.normal_vocab:
                 common_to_current[-1] = common_to_current[-1] + 1
                 level_change(ms_map, m, current_to_common, common_to_current)
             else:
@@ -319,6 +322,7 @@ class MeaningDataset(Dataset):
         start,
         end,
         vocab_size,
+        reserve_vocab=0,
         size=None,
         max_subitem=10,
         min_subitem=1,
@@ -332,6 +336,7 @@ class MeaningDataset(Dataset):
         self.start = start
         self.end = end
         self.vocab_size = vocab_size
+        self.reserve_vocab = reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
         self.stride = stride
@@ -407,7 +412,14 @@ class MeaningDataset(Dataset):

     def get_meaning_map(self):
         return MeaningMap(
-            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
+            self.end,
+            self.vocab_size,
+            self.reserve_vocab,
+            self.max_subitem,
+            self.min_subitem,
+            self.stride,
+            self.with_tree,
+            self.use_cache,
         )

     def set_mask(self, level=None, idx=None):
diff --git a/wit/query_block_output.py b/wit/query_block_output.py
index 379427c..397f9df 100644
--- a/wit/query_block_output.py
+++ b/wit/query_block_output.py
@@ -44,7 +44,8 @@ if __name__ == "__main__":

     qwen.llm.hook_attention = DumpQK

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf)
+    val = val.dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
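
Note: the sketch below mirrors the vocab budgeting this patch introduces in MeaningMap.__init__. The helper name vocab_budget and the example numbers are illustrative assumptions, not part of the patch.

    def vocab_budget(vocab_size, reserve_vocab, stride, with_tree):
        # Count the special ids the map needs, following MeaningMap.__init__.
        special_vocab = 0
        if stride > 1:
            special_vocab = special_vocab + 1  # one id for the stride token
        if with_tree:
            special_vocab = special_vocab + 1  # one id for the tree token
        # The reserved budget must cover every special id; whatever remains
        # after the reservation is usable for normal leaf meanings.
        assert reserve_vocab >= special_vocab, "must reserve enough vocab for special"
        normal_vocab = vocab_size - reserve_vocab
        return special_vocab, normal_vocab

    # e.g. vocab_size=4096 with both special tokens enabled needs reserve_vocab >= 2:
    print(vocab_budget(4096, 2, stride=2, with_tree=True))  # prints (2, 4094)

On the caller side, InitValDataset is gone: both loaders now come from InitDataset and the validation loader is unpacked, as in the inference.py and query_block_output.py hunks above (_, val = ds.InitDataset(conf)).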