Add reserve_vocab support, delete InitValDataset.

Colin 2025-08-14 14:51:43 +08:00
parent a983a2b6e6
commit b2fe00c157
6 changed files with 33 additions and 74 deletions
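
In short: InitValDataset is removed and InitDataset becomes the single entry point, returning the train and validation dataloaders as a pair; a new reserve_vocab field on the meaning-dataset config is threaded through to MeaningDataset and MeaningMap and becomes part of every cache filename. A minimal sketch of the new call pattern, assuming a config object shaped like the one this repo passes around (reserve_vocab and InitDataset come from the diff; the import name and config builder are illustrative):

    import dataset as ds  # the package whose __init__.py loses InitValDataset below

    conf = make_config()                     # hypothetical helper; build the run config however the repo does
    conf.dataset.meaning.reserve_vocab = 2   # new knob added by this commit (defaults to 0)
    train_loader, val_loader = ds.InitDataset(conf)  # was: ds.InitValDataset(conf) for val-only callers
    val_ds = val_loader.dataset              # underlying dataset, as the updated scripts below do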

View File

@@ -40,6 +40,7 @@ class MeaningDatasetConfig:
     def __init__(self):
         self.start = 10000
         self.end = 200000
+        self.reserve_vocab = 0
         self.size = None
         self.min_subitem = 2
         self.max_subitem = 10

View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
     runner = ModelRunner(qwen.llm)

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
    md = val.meaning_dataset
     map = md.get_meaning_map()

View File

@@ -1,2 +1 @@
 from .dataset import InitDataset
-from .dataset import InitValDataset

View File

@@ -35,10 +35,11 @@ def InitDataset(config):
         size = c.size
         end = c.end
         seed = c.seed
+        reserve_vocab = c.reserve_vocab
         path = "./data/"
         conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
+        conf_name = conf_name + f"_vocab{vocab}_reserve_vocab{reserve_vocab}_stride{c.stride}_tree{c.with_tree}.pt"
         trainfile = path + f"MeaningDataset_train" + conf_name
         valfile = path + f"MeaningDataset_val" + conf_name
         if not os.path.exists(path):
@@ -56,6 +57,7 @@ def InitDataset(config):
                 start,
                 end,
                 vocab,
+                reserve_vocab,
                 size,
                 c.max_subitem,
                 c.min_subitem,
@@ -74,58 +76,3 @@ def InitDataset(config):
             )
         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
         return train_dataloader, val_dataloader
-
-
-def InitValDataset(config):
-    val_batch_size = config.val_batch_size
-    num_proc = config.num_proc
-    if config.dataset.name == "special":
-        raw_dataset = SpecialDataset()
-        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=val_batch_size,
-            num_workers=num_proc,
-            persistent_workers=True,
-        )
-        return val_dataloader
-
-    if config.dataset.name == "meaning":
-        c = config.dataset.meaning
-        vocab = config.model_config.vocab_size
-        start = c.start
-        size = c.size
-        end = c.end
-        seed = c.seed
-        path = "./data/"
-        conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
-        valfile = path + f"MeaningDataset_val" + conf_name
-        if not os.path.exists(path):
-            os.mkdir(path)
-        if os.path.exists(valfile):
-            print(f"INFO: Load dataset from {valfile}")
-            val_dataset = torch.load(valfile, weights_only=False)
-            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            print(f"INFO: Load dataset end")
-        else:
-            raw_dataset = MeaningDataset(
-                start,
-                end,
-                vocab,
-                size,
-                c.max_subitem,
-                c.min_subitem,
-                stride=c.stride,
-                with_tree=c.with_trees,
-                seed=seed,
-            )
-            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            train_dataset, val_dataset = raw_dataset.split(0.9)
-            torch.save(val_dataset, valfile)
-            print(f"INFO: Build and save dataset end")
-        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
-        return val_dataloader
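
A side effect worth noting: reserve_vocab is now baked into conf_name, so any .pt cache written before this commit no longer matches and the dataset is rebuilt on first use. As a worked example, with start=10000, end=200000, size=None, min_subitem=2, max_subitem=10 (the config defaults above) and illustrative seed=42, vocab=4096, reserve_vocab=2, stride=1, with_tree=False, the two f-strings produce:

    _s10000_e200000_sNone_ms2_maxs10_seed42_vocab4096_reserve_vocab2_stride1_treeFalse.pt

giving ./data/MeaningDataset_train_s10000_e200000_sNone_ms2_maxs10_seed42_vocab4096_reserve_vocab2_stride1_treeFalse.pt for the train split.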

View File

@@ -20,6 +20,7 @@ class MeaningMap:
         self,
         size=1048576,
         vocab_size=4096,
+        reserve_vocab=0,
         max_subitem=10,
         min_subitem=1,
         stride=1,
@@ -31,15 +32,17 @@ class MeaningMap:
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
         self.size = size
-        self.special_vocab_size = 0
+        self.reserve_vocab = reserve_vocab
+        self.special_vocab = 0
         if stride > 1:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_stride = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_stride = vocab_size - self.special_vocab
         if with_tree:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_tree = vocab_size - self.special_vocab_size
-        self.normal_vocab_size = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_tree = vocab_size - self.special_vocab
+        assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
+        self.normal_vocab = vocab_size - self.reserve_vocab

         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
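
The renames make the two quantities explicit: special_vocab counts the ids actually consumed by enabled features (one when stride > 1, one when with_tree), while reserve_vocab is the caller-declared budget carved off the top of the vocabulary, and the new assert checks budget against usage. Worked through with illustrative numbers:

    vocab_size    = 4096
    reserve_vocab = 2            # budget declared by the caller
    special_vocab = 2            # stride > 1 and with_tree both enabled
    assert reserve_vocab >= special_vocab   # 2 >= 2 passes; reserve_vocab=1 here would raise
    normal_vocab  = vocab_size - reserve_vocab   # 4094 ids left for ordinary leaf meanings

Unlike the old normal_vocab_size, normal_vocab depends only on reserve_vocab, so the id layout stays stable when stride or with_tree is toggled, as long as the budget covers both.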
@@ -49,7 +52,7 @@ class MeaningMap:
         path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size)
-        file += "_" + str(max_subitem) + "_" + str(min_subitem)
+        file += "_" + str(reserve_vocab) + "_" + str(max_subitem) + "_" + str(min_subitem)
         file += "_" + str(stride) + "_" + str(with_tree) + "_" + str(seed)
         file_prop = path + file + "_prop.npy"
         file_data = path + file + "_data.npy"
@@ -100,8 +103,8 @@ class MeaningMap:
         map[mask_zero] = -1
-        map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
-        map[: self.normal_vocab_size, 1:] = -1
+        map[: self.normal_vocab, 0] = np.arange(0, self.normal_vocab)
+        map[: self.normal_vocab, 1:] = -1

         ms_level = []  # meaning level, vocab's level is 0
         ms_rank_idx = []  # meaning index of all level
@@ -118,7 +121,7 @@ class MeaningMap:
         index = 0
-        for i in range(self.normal_vocab_size):
+        for i in range(self.normal_vocab):
             ms_data[index] = i
             ms_level[index] = 0
             ms_rank_idx[index] = 0xFFFFFFF
@@ -135,14 +138,14 @@ class MeaningMap:
                 ms_weight[i] = 1
             index = index + stride

-        for i in range(self.normal_vocab_size, size):
+        for i in range(self.normal_vocab, size):
             m = map[i]  # the branches that the current meaning splits into
             m = m[m >= 0]  # do not cut off the map such as [0]
             m_len = len(m)  # number of branches of the current meaning's split
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"

-            # get each sub-meaning's start and end and concatenate the ranges into the current meaning's full leaf indices (< self.normal_vocab_size)
+            # get each sub-meaning's start and end and concatenate the ranges into the current meaning's full leaf indices (< self.normal_vocab)
             idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
             idxidx = np.concatenate(
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
@@ -246,7 +249,7 @@ class MeaningMap:
         root = NodeTree(str(meaning))
         seqlist = []
-        get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
+        get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
         root.seq_node = seqlist
         return root
@@ -256,7 +259,7 @@ class MeaningMap:
         def level_change(ms_map, meaning, current_to_common, common_to_current):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
-                if m >= self.normal_vocab_size:
+                if m >= self.normal_vocab:
                     common_to_current[-1] = common_to_current[-1] + 1
                     level_change(ms_map, m, current_to_common, common_to_current)
                 else:
@@ -319,6 +322,7 @@ class MeaningDataset(Dataset):
         start,
         end,
         vocab_size,
+        reserve_vocab=0,
         size=None,
         max_subitem=10,
         min_subitem=1,
@@ -332,6 +336,7 @@ class MeaningDataset(Dataset):
         self.start = start
         self.end = end
         self.vocab_size = vocab_size
+        self.reserve_vocab = reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
         self.stride = stride
@@ -407,7 +412,14 @@ class MeaningDataset(Dataset):
     def get_meaning_map(self):
         return MeaningMap(
-            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
+            self.end,
+            self.vocab_size,
+            self.reserve_vocab,
+            self.max_subitem,
+            self.min_subitem,
+            self.stride,
+            self.with_tree,
+            self.use_cache,
         )

     def set_mask(self, level=None, idx=None):
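
Because reserve_vocab is inserted directly after vocab_size in both the MeaningMap and MeaningDataset signatures, positional call sites outside this diff need the same treatment as get_meaning_map above. A construction sketch with illustrative values (keyword arguments avoid the ordering pitfall):

    mm = MeaningMap(
        size=1048576,
        vocab_size=4096,
        reserve_vocab=2,   # must cover every special id in use
        max_subitem=10,
        min_subitem=1,
        stride=2,          # stride > 1 consumes one special id
        with_tree=True,    # consumes another
    )

Enabling stride or with_tree without raising reserve_vocab now fails fast at the new assert, where the old code would silently shrink normal_vocab_size instead.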

View File

@@ -44,7 +44,7 @@ if __name__ == "__main__":
     qwen.llm.hook_attention = DumpQK
-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()