Add reserve_vocab support, delete InitValDataset.
parent a983a2b6e6
commit b2fe00c157
@@ -40,6 +40,7 @@ class MeaningDatasetConfig:
     def __init__(self):
         self.start = 10000
         self.end = 200000
+        self.reserve_vocab = 0
         self.size = None
         self.min_subitem = 2
         self.max_subitem = 10
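
The new field defaults to 0, so existing configs are unaffected until a caller opts in. A hypothetical opt-in, with an illustrative value that is not taken from this commit:

    conf = MeaningDatasetConfig()
    conf.reserve_vocab = 2  # illustrative: enough ids for the stride and tree special tokens below
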
@@ -21,7 +21,7 @@ if __name__ == "__main__":

     runner = ModelRunner(qwen.llm)

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()

@@ -1,2 +1 @@
 from .dataset import InitDataset
-from .dataset import InitValDataset

@@ -35,10 +35,11 @@ def InitDataset(config):
         size = c.size
         end = c.end
         seed = c.seed
+        reserve_vocab = c.reserve_vocab

         path = "./data/"
         conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
+        conf_name = conf_name + f"_vocab{vocab}_reserve_vocab{reserve_vocab}_stride{c.stride}_tree{c.with_tree}.pt"
         trainfile = path + f"MeaningDataset_train" + conf_name
         valfile = path + f"MeaningDataset_val" + conf_name
         if not os.path.exists(path):
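
Because reserve_vocab is now baked into conf_name, cache files built before this commit no longer match and the datasets will be rebuilt on first run. A sketch of the resulting file name, with illustrative values for every field:

    # Illustrative values only (start=10000, end=200000, size=None, seed=42,
    # vocab=4096, reserve_vocab=2, stride=1, with_tree=False) yield:
    # ./data/MeaningDataset_train_s10000_e200000_sNone_ms2_maxs10_seed42_vocab4096_reserve_vocab2_stride1_treeFalse.pt
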
@@ -56,6 +57,7 @@ def InitDataset(config):
             start,
             end,
             vocab,
+            reserve_vocab,
             size,
             c.max_subitem,
             c.min_subitem,
@@ -74,58 +76,3 @@ def InitDataset(config):
         )
         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
         return train_dataloader, val_dataloader
-
-
-def InitValDataset(config):
-
-    val_batch_size = config.val_batch_size
-    num_proc = config.num_proc
-    if config.dataset.name == "special":
-        raw_dataset = SpecialDataset()
-        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=val_batch_size,
-            num_workers=num_proc,
-            persistent_workers=True,
-        )
-        return val_dataloader
-
-    if config.dataset.name == "meaning":
-        c = config.dataset.meaning
-        vocab = config.model_config.vocab_size
-        start = c.start
-        size = c.size
-        end = c.end
-        seed = c.seed
-
-        path = "./data/"
-        conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
-        valfile = path + f"MeaningDataset_val" + conf_name
-        if not os.path.exists(path):
-            os.mkdir(path)
-        if os.path.exists(valfile):
-            print(f"INFO: Load dataset from {valfile}")
-            val_dataset = torch.load(valfile, weights_only=False)
-            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            print(f"INFO: Load dataset end")
-        else:
-            raw_dataset = MeaningDataset(
-                start,
-                end,
-                vocab,
-                size,
-                c.max_subitem,
-                c.min_subitem,
-                stride=c.stride,
-                with_tree=c.with_trees,
-                seed=seed,
-            )
-            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            train_dataset, val_dataset = raw_dataset.split(0.9)
-            torch.save(val_dataset, valfile)
-            print(f"INFO: Build and save dataset end")
-
-        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
-        return val_dataloader
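
The deleted InitValDataset duplicated almost all of InitDataset's "meaning" branch, differing only in building and caching the val split alone. Callers now go through the single entry point and discard the train half; the migration pattern, as used at both call sites in this commit:

    # before: val = ds.InitValDataset(conf).dataset
    # after:
    _, val = ds.InitDataset(conf).dataset
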
@@ -20,6 +20,7 @@ class MeaningMap:
         self,
         size=1048576,
         vocab_size=4096,
+        reserve_vocab=0,
         max_subitem=10,
         min_subitem=1,
         stride=1,
@@ -31,15 +32,17 @@ class MeaningMap:
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
         self.size = size
+        self.reserve_vocab = reserve_vocab

-        self.special_vocab_size = 0
+        self.special_vocab = 0
         if stride > 1:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_stride = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_stride = vocab_size - self.special_vocab
         if with_tree:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_tree = vocab_size - self.special_vocab_size
-        self.normal_vocab_size = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_tree = vocab_size - self.special_vocab
+        assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
+        self.normal_vocab = vocab_size - self.reserve_vocab

         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
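
The renamed counters split the vocabulary as follows: special tokens (one for stride > 1, one for with_tree) are counted in special_vocab, the caller must reserve at least that many ids via reserve_vocab, and ordinary leaf tokens occupy ids 0..normal_vocab-1 with normal_vocab = vocab_size - reserve_vocab. Note that with the default reserve_vocab=0, any config using stride > 1 or with_tree now trips the new assertion until reserve_vocab is raised. A worked example of the arithmetic, with illustrative values:

    vocab_size = 4096
    stride, with_tree = 2, True
    reserve_vocab = 2                          # caller-supplied

    special_vocab = 0
    if stride > 1:
        special_vocab += 1                     # stride separator id -> 4095
    if with_tree:
        special_vocab += 1                     # tree marker id -> 4094
    assert reserve_vocab >= special_vocab      # 2 >= 2, passes
    normal_vocab = vocab_size - reserve_vocab  # 4094 leaf ids: 0..4093
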
@@ -49,7 +52,7 @@ class MeaningMap:

         path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size)
-        file += "_" + str(max_subitem) + "_" + str(min_subitem)
+        file += "_" + str(reserve_vocab) + "_" + str(max_subitem) + "_" + str(min_subitem)
         file += "_" + str(stride) + "_" + str(with_tree) + "_" + str(seed)
         file_prop = path + file + "_prop.npy"
         file_data = path + file + "_data.npy"
@@ -100,8 +103,8 @@ class MeaningMap:

         map[mask_zero] = -1

-        map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
-        map[: self.normal_vocab_size, 1:] = -1
+        map[: self.normal_vocab, 0] = np.arange(0, self.normal_vocab)
+        map[: self.normal_vocab, 1:] = -1

         ms_level = []  # meaning level, vocab's level is 0
         ms_rank_idx = []  # meaning index of all level
@@ -118,7 +121,7 @@ class MeaningMap:

         index = 0

-        for i in range(self.normal_vocab_size):
+        for i in range(self.normal_vocab):
             ms_data[index] = i
             ms_level[index] = 0
             ms_rank_idx[index] = 0xFFFFFFF
@@ -135,14 +138,14 @@ class MeaningMap:
             ms_weight[i] = 1
             index = index + stride

-        for i in range(self.normal_vocab_size, size):
+        for i in range(self.normal_vocab, size):
             m = map[i]  # branches that the current meaning splits into
             m = m[m >= 0]  # do not cut off the map such as [0]
             m_len = len(m)  # number of branches of the current meaning
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"

-            # get the start and end of each sub-meaning and stitch them into the current meaning's full leaf index (< self.normal_vocab_size)
+            # get the start and end of each sub-meaning and stitch them into the current meaning's full leaf index (< self.normal_vocab)
             idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
             idxidx = np.concatenate(
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
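
After the rename, the construction loop reads as a two-phase build: ids below normal_vocab are emitted directly as leaves, and ids in [normal_vocab, size) are composites expanded from their branch rows. A minimal standalone sketch of that expansion, assuming only the map layout shown above (row i lists the children of meaning i, padded with -1):

    def expand(ms_map, meaning, normal_vocab):
        # Recursively flatten a meaning into its leaf-token sequence.
        # ms_map: 2-D integer array; row i lists children of meaning i, -1 padded.
        if meaning < normal_vocab:
            return [meaning]                 # leaf: a plain vocab token
        seq = []
        for m in ms_map[meaning]:
            if m >= 0:                       # skip -1 padding
                seq += expand(ms_map, int(m), normal_vocab)
        return seq
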
@@ -246,7 +249,7 @@ class MeaningMap:

         root = NodeTree(str(meaning))
         seqlist = []
-        get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
+        get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
         root.seq_node = seqlist
         return root

@@ -256,7 +259,7 @@ class MeaningMap:
         def level_change(ms_map, meaning, current_to_common, common_to_current):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
-                if m >= self.normal_vocab_size:
+                if m >= self.normal_vocab:
                     common_to_current[-1] = common_to_current[-1] + 1
                     level_change(ms_map, m, current_to_common, common_to_current)
                 else:
@@ -319,6 +322,7 @@ class MeaningDataset(Dataset):
         start,
         end,
         vocab_size,
+        reserve_vocab=0,
         size=None,
         max_subitem=10,
         min_subitem=1,
@@ -332,6 +336,7 @@ class MeaningDataset(Dataset):
         self.start = start
         self.end = end
         self.vocab_size = vocab_size
+        self.reserve_vocab = reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
         self.stride = stride
@@ -407,7 +412,14 @@ class MeaningDataset(Dataset):

     def get_meaning_map(self):
         return MeaningMap(
-            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
+            self.end,
+            self.vocab_size,
+            self.reserve_vocab,
+            self.max_subitem,
+            self.min_subitem,
+            self.stride,
+            self.with_tree,
+            self.use_cache,
         )

     def set_mask(self, level=None, idx=None):
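
get_meaning_map had to be rewritten in lockstep because MeaningMap takes reserve_vocab as its third positional parameter. An alternative sketch using keyword arguments, which would make future parameter insertions non-breaking (this assumes with_tree and use_cache are the parameter names, as the positional call suggests; it is not what the commit does):

    return MeaningMap(
        size=self.end,
        vocab_size=self.vocab_size,
        reserve_vocab=self.reserve_vocab,
        max_subitem=self.max_subitem,
        min_subitem=self.min_subitem,
        stride=self.stride,
        with_tree=self.with_tree,
        use_cache=self.use_cache,
    )
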
@@ -44,7 +44,7 @@ if __name__ == "__main__":

     qwen.llm.hook_attention = DumpQK

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
