Add reserve_vocab support, delete InitValDataset.

Colin 2025-08-14 14:51:43 +08:00
parent a983a2b6e6
commit b2fe00c157
6 changed files with 33 additions and 74 deletions

View File

@@ -40,6 +40,7 @@ class MeaningDatasetConfig:
     def __init__(self):
         self.start = 10000
         self.end = 200000
+        self.reserve_vocab = 0
         self.size = None
         self.min_subitem = 2
         self.max_subitem = 10

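A minimal sketch of how the new knob might be set from training code, assuming MeaningDatasetConfig is importable from the dataset module (the import path and the value 16 are illustrative, not from this commit):

    from dataset import MeaningDatasetConfig  # hypothetical import path

    conf = MeaningDatasetConfig()
    conf.reserve_vocab = 16  # hold back 16 token ids for special tokens; 0 keeps the old behavior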
View File

@@ -21,7 +21,7 @@ if __name__ == "__main__":
     runner = ModelRunner(qwen.llm)
-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()

View File

@@ -1,2 +1 @@
 from .dataset import InitDataset
-from .dataset import InitValDataset

View File

@@ -35,10 +35,11 @@ def InitDataset(config):
         size = c.size
         end = c.end
         seed = c.seed
+        reserve_vocab = c.reserve_vocab
         path = "./data/"
         conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
+        conf_name = conf_name + f"_vocab{vocab}_reserve_vocab{reserve_vocab}_stride{c.stride}_tree{c.with_tree}.pt"
         trainfile = path + f"MeaningDataset_train" + conf_name
         valfile = path + f"MeaningDataset_val" + conf_name
         if not os.path.exists(path):
@@ -56,6 +57,7 @@ def InitDataset(config):
                 start,
                 end,
                 vocab,
+                reserve_vocab,
                 size,
                 c.max_subitem,
                 c.min_subitem,
@@ -74,58 +76,3 @@ def InitDataset(config):
         )
         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
         return train_dataloader, val_dataloader
-
-
-def InitValDataset(config):
-    val_batch_size = config.val_batch_size
-    num_proc = config.num_proc
-    if config.dataset.name == "special":
-        raw_dataset = SpecialDataset()
-        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=val_batch_size,
-            num_workers=num_proc,
-            persistent_workers=True,
-        )
-        return val_dataloader
-    if config.dataset.name == "meaning":
-        c = config.dataset.meaning
-        vocab = config.model_config.vocab_size
-        start = c.start
-        size = c.size
-        end = c.end
-        seed = c.seed
-        path = "./data/"
-        conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
-        valfile = path + f"MeaningDataset_val" + conf_name
-        if not os.path.exists(path):
-            os.mkdir(path)
-        if os.path.exists(valfile):
-            print(f"INFO: Load dataset from {valfile}")
-            val_dataset = torch.load(valfile, weights_only=False)
-            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            print(f"INFO: Load dataset end")
-        else:
-            raw_dataset = MeaningDataset(
-                start,
-                end,
-                vocab,
-                size,
-                c.max_subitem,
-                c.min_subitem,
-                stride=c.stride,
-                with_tree=c.with_trees,
-                seed=seed,
-            )
-            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            train_dataset, val_dataset = raw_dataset.split(0.9)
-            torch.save(val_dataset, valfile)
-            print(f"INFO: Build and save dataset end")
-        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
-        return val_dataloader

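With InitValDataset removed, a call site takes the validation loader from InitDataset's pair and discards the training half. A sketch of the pattern, assuming the returned loader still exposes the underlying split through a .dataset attribute the way the old InitValDataset(conf).dataset did, and that the caller re-applies the validation mask InitValDataset used to set:

    _, val_loader = ds.InitDataset(conf)
    val = val_loader.dataset  # assumed attribute, mirrors InitValDataset(conf).dataset
    val.set_mask(conf.dataset.meaning.val_mask_level, conf.dataset.meaning.val_mask_idx)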
View File

@@ -20,6 +20,7 @@ class MeaningMap:
         self,
         size=1048576,
         vocab_size=4096,
+        reserve_vocab=0,
         max_subitem=10,
         min_subitem=1,
         stride=1,
@@ -31,15 +32,17 @@
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
         self.size = size
-        self.special_vocab_size = 0
+        self.reserve_vocab = reserve_vocab
+        self.special_vocab = 0
         if stride > 1:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_stride = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_stride = vocab_size - self.special_vocab
         if with_tree:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_tree = vocab_size - self.special_vocab_size
-        self.normal_vocab_size = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_tree = vocab_size - self.special_vocab
+        assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
+        self.normal_vocab = vocab_size - self.reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
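To make the new partition concrete, a worked example of the logic above with illustrative numbers (4096 and 16 are placeholders, not defaults from this commit):

    vocab_size, reserve_vocab = 4096, 16
    stride, with_tree = 2, True
    special_vocab = 0
    if stride > 1:
        special_vocab += 1
        vocab_of_stride = vocab_size - special_vocab  # 4095: id of the stride marker
    if with_tree:
        special_vocab += 1
        vocab_of_tree = vocab_size - special_vocab    # 4094: id of the tree marker
    assert reserve_vocab >= special_vocab             # reserved region must cover every special id
    normal_vocab = vocab_size - reserve_vocab         # ids 0..4079 remain ordinary leaf tokens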
@@ -49,7 +52,7 @@
         path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size)
-        file += "_" + str(max_subitem) + "_" + str(min_subitem)
+        file += "_" + str(reserve_vocab) + "_" + str(max_subitem) + "_" + str(min_subitem)
         file += "_" + str(stride) + "_" + str(with_tree) + "_" + str(seed)
         file_prop = path + file + "_prop.npy"
         file_data = path + file + "_data.npy"
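Because reserve_vocab now participates in the cache-file name, caches built before this commit are simply not found and the map is rebuilt. For example, with the illustrative values size=1048576, vocab_size=4096, reserve_vocab=16, max_subitem=10, min_subitem=2, stride=1, with_tree=True, seed=42, the key becomes:

    structured_language_1048576_4096_16_10_2_1_True_42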
@@ -100,8 +103,8 @@
             map[mask_zero] = -1
-            map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
-            map[: self.normal_vocab_size, 1:] = -1
+            map[: self.normal_vocab, 0] = np.arange(0, self.normal_vocab)
+            map[: self.normal_vocab, 1:] = -1

             ms_level = []  # meaning level, vocab's level is 0
             ms_rank_idx = []  # meaning index of all level
@@ -118,7 +121,7 @@
             index = 0
-            for i in range(self.normal_vocab_size):
+            for i in range(self.normal_vocab):
                 ms_data[index] = i
                 ms_level[index] = 0
                 ms_rank_idx[index] = 0xFFFFFFF
@@ -135,14 +138,14 @@
                 ms_weight[i] = 1
                 index = index + stride

-            for i in range(self.normal_vocab_size, size):
+            for i in range(self.normal_vocab, size):
                 m = map[i]  # the branches the current meaning splits into
                 m = m[m >= 0]  # do not cut off the map, such as [0]
                 m_len = len(m)  # number of branches the current meaning splits into
                 m_list = m.tolist()
                 assert m_list, "map list can not be empty list"

-                # get each sub-meaning's start and end, then concatenate the ranges into the current meaning's full leaf sequence (leaf index < self.normal_vocab_size)
+                # get each sub-meaning's start and end, then concatenate the ranges into the current meaning's full leaf sequence (leaf index < self.normal_vocab)
                 idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
                 idxidx = np.concatenate(
                     [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
@@ -246,7 +249,7 @@
         root = NodeTree(str(meaning))
         seqlist = []
-        get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
+        get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
         root.seq_node = seqlist
         return root
@@ -256,7 +259,7 @@
         def level_change(ms_map, meaning, current_to_common, common_to_current):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
-                if m >= self.normal_vocab_size:
+                if m >= self.normal_vocab:
                     common_to_current[-1] = common_to_current[-1] + 1
                     level_change(ms_map, m, current_to_common, common_to_current)
                 else:
@@ -319,6 +322,7 @@ class MeaningDataset(Dataset):
         start,
         end,
         vocab_size,
+        reserve_vocab=0,
         size=None,
         max_subitem=10,
         min_subitem=1,
@@ -332,6 +336,7 @@
         self.start = start
         self.end = end
         self.vocab_size = vocab_size
+        self.reserve_vocab = reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
         self.stride = stride
@@ -407,7 +412,14 @@
     def get_meaning_map(self):
         return MeaningMap(
-            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
+            self.end,
+            self.vocab_size,
+            self.reserve_vocab,
+            self.max_subitem,
+            self.min_subitem,
+            self.stride,
+            self.with_tree,
+            self.use_cache,
         )

     def set_mask(self, level=None, idx=None):

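Putting the pieces together, a sketch of building a dataset with the widened constructor and deriving its map (all values illustrative; keyword arguments used for clarity):

    md = MeaningDataset(
        start=10000,
        end=200000,
        vocab_size=4096,
        reserve_vocab=16,  # must cover the special tokens implied by stride/with_tree
        size=None,
        max_subitem=10,
        min_subitem=2,
    )
    mmap = md.get_meaning_map()  # now forwards reserve_vocab to MeaningMap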
View File

@@ -44,7 +44,7 @@ if __name__ == "__main__":
     qwen.llm.hook_attention = DumpQK
-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()