Add reserve_vocab support, delete InitValDataset.
This commit is contained in:
parent a983a2b6e6
commit b2fe00c157
@@ -40,6 +40,7 @@ class MeaningDatasetConfig:
     def __init__(self):
         self.start = 10000
         self.end = 200000
+        self.reserve_vocab = 0
         self.size = None
         self.min_subitem = 2
         self.max_subitem = 10
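For orientation, a minimal sketch of how the new field is set. The defaults are the ones in the hunk above; the value 2 and the comment's semantics are assumptions, spelled out by the MeaningMap hunks further down:

    conf = MeaningDatasetConfig()
    conf.reserve_vocab = 2  # assumption: keep the top 2 vocab ids for special tokens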
@@ -21,7 +21,7 @@ if __name__ == "__main__":

     runner = ModelRunner(qwen.llm)

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()

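InitDataset returns a plain (train_dataloader, val_dataloader) tuple (see the dataset.py hunk below), and a Python tuple has no .dataset attribute, so the new call site presumably means to unpack first. A hedged sketch of that reading:

    # Sketch (assumption): unpack the pair, then take the underlying dataset
    # off the validation dataloader before reaching meaning_dataset.
    _, val_dl = ds.InitDataset(conf)
    val = val_dl.dataset
    md = val.meaning_dataset
    map = md.get_meaning_map()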
@@ -1,2 +1 @@
 from .dataset import InitDataset
-from .dataset import InitValDataset
@@ -35,10 +35,11 @@ def InitDataset(config):
         size = c.size
         end = c.end
         seed = c.seed
+        reserve_vocab = c.reserve_vocab

         path = "./data/"
         conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
+        conf_name = conf_name + f"_vocab{vocab}_reserve_vocab{reserve_vocab}_stride{c.stride}_tree{c.with_tree}.pt"
         trainfile = path + f"MeaningDataset_train" + conf_name
         valfile = path + f"MeaningDataset_val" + conf_name
         if not os.path.exists(path):
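Since reserve_vocab is now baked into conf_name, caches written before this commit no longer match the new name and will be rebuilt. A worked example of the new name, using the config defaults above plus assumed values vocab=4096, seed=42, stride=1, with_tree=False:

    ./data/MeaningDataset_train_s10000_e200000_sNone_ms2_maxs10_seed42_vocab4096_reserve_vocab0_stride1_treeFalse.pt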
@@ -56,6 +57,7 @@ def InitDataset(config):
                 start,
                 end,
                 vocab,
+                reserve_vocab,
                 size,
                 c.max_subitem,
                 c.min_subitem,
@@ -74,58 +76,3 @@ def InitDataset(config):
         )
         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
         return train_dataloader, val_dataloader
-
-
-def InitValDataset(config):
-    val_batch_size = config.val_batch_size
-    num_proc = config.num_proc
-    if config.dataset.name == "special":
-        raw_dataset = SpecialDataset()
-        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=val_batch_size,
-            num_workers=num_proc,
-            persistent_workers=True,
-        )
-        return val_dataloader
-
-    if config.dataset.name == "meaning":
-        c = config.dataset.meaning
-        vocab = config.model_config.vocab_size
-        start = c.start
-        size = c.size
-        end = c.end
-        seed = c.seed
-
-        path = "./data/"
-        conf_name = f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_seed{seed}"
-        conf_name = conf_name + f"_vocab{vocab}_stride{c.stride}_tree{c.with_tree}.pt"
-        valfile = path + f"MeaningDataset_val" + conf_name
-        if not os.path.exists(path):
-            os.mkdir(path)
-        if os.path.exists(valfile):
-            print(f"INFO: Load dataset from {valfile}")
-            val_dataset = torch.load(valfile, weights_only=False)
-            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            print(f"INFO: Load dataset end")
-        else:
-            raw_dataset = MeaningDataset(
-                start,
-                end,
-                vocab,
-                size,
-                c.max_subitem,
-                c.min_subitem,
-                stride=c.stride,
-                with_tree=c.with_trees,
-                seed=seed,
-            )
-            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
-            train_dataset, val_dataset = raw_dataset.split(0.9)
-            torch.save(val_dataset, valfile)
-            print(f"INFO: Build and save dataset end")
-
-        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
-        return val_dataloader
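With InitValDataset removed, validation-only callers go through InitDataset and drop the train loader. A minimal migration sketch (assumption: building the train split as a side effect is acceptable):

    # Before (removed in this commit):
    #   val_dl = InitValDataset(config)
    # After:
    _, val_dl = InitDataset(config)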
@@ -20,6 +20,7 @@ class MeaningMap:
         self,
         size=1048576,
         vocab_size=4096,
+        reserve_vocab=0,
         max_subitem=10,
         min_subitem=1,
         stride=1,
@@ -31,15 +32,17 @@ class MeaningMap:
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
         self.size = size
+        self.reserve_vocab = reserve_vocab

-        self.special_vocab_size = 0
+        self.special_vocab = 0
         if stride > 1:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_stride = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_stride = vocab_size - self.special_vocab
         if with_tree:
-            self.special_vocab_size = self.special_vocab_size + 1
-            vocab_of_tree = vocab_size - self.special_vocab_size
-        self.normal_vocab_size = vocab_size - self.special_vocab_size
+            self.special_vocab = self.special_vocab + 1
+            vocab_of_tree = vocab_size - self.special_vocab
+        assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
+        self.normal_vocab = vocab_size - self.reserve_vocab

         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
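To make the new bookkeeping concrete, a standalone sketch of the logic above with assumed values (vocab_size=4096, reserve_vocab=8, stride=2, with_tree=True):

    vocab_size, reserve_vocab = 4096, 8
    stride, with_tree = 2, True

    special_vocab = 0
    if stride > 1:
        special_vocab += 1
        vocab_of_stride = vocab_size - special_vocab  # 4095: stride marker id
    if with_tree:
        special_vocab += 1
        vocab_of_tree = vocab_size - special_vocab  # 4094: tree marker id
    assert reserve_vocab >= special_vocab  # 8 >= 2 holds
    normal_vocab = vocab_size - reserve_vocab  # 4088 ids left for ordinary meanings

So ordinary meanings draw from ids [0, 4088), and the top 8 ids stay reserved, with the stride and tree markers sitting at 4095 and 4094 inside that reserved band.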
@@ -49,7 +52,7 @@ class MeaningMap:

         path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size)
-        file += "_" + str(max_subitem) + "_" + str(min_subitem)
+        file += "_" + str(reserve_vocab) + "_" + str(max_subitem) + "_" + str(min_subitem)
         file += "_" + str(stride) + "_" + str(with_tree) + "_" + str(seed)
         file_prop = path + file + "_prop.npy"
         file_data = path + file + "_data.npy"
@@ -100,8 +103,8 @@ class MeaningMap:

         map[mask_zero] = -1

-        map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
-        map[: self.normal_vocab_size, 1:] = -1
+        map[: self.normal_vocab, 0] = np.arange(0, self.normal_vocab)
+        map[: self.normal_vocab, 1:] = -1

         ms_level = []  # meaning level, vocab's level is 0
         ms_rank_idx = []  # meaning index of all level
@@ -118,7 +121,7 @@ class MeaningMap:

         index = 0

-        for i in range(self.normal_vocab_size):
+        for i in range(self.normal_vocab):
             ms_data[index] = i
             ms_level[index] = 0
             ms_rank_idx[index] = 0xFFFFFFF
@@ -135,14 +138,14 @@ class MeaningMap:
             ms_weight[i] = 1
             index = index + stride

-        for i in range(self.normal_vocab_size, size):
+        for i in range(self.normal_vocab, size):
             m = map[i]  # branches that the current meaning splits into
             m = m[m >= 0]  # do not cut off the map such as [0]
             m_len = len(m)  # number of branches in the current meaning's split
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"

-            # get each sub-meaning's start and end, then compose the sequences into the current meaning's full leaf indices (< self.normal_vocab_size)
+            # get each sub-meaning's start and end, then compose the sequences into the current meaning's full leaf indices (< self.normal_vocab)
             idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
             idxidx = np.concatenate(
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
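For intuition about this loop, a toy illustration of how a meaning id at or above normal_vocab flattens into leaf ids. The sizes and map rows are invented; only the expansion rule mirrors the code above:

    import numpy as np

    NORMAL_VOCAB = 5  # toy value; the real one is vocab_size - reserve_vocab

    # Rows 0..4 are leaves (own id in column 0, -1 padding), matching the
    # map[: normal_vocab] initialization in the earlier hunk.
    toy_map = np.array([
        [0, -1], [1, -1], [2, -1], [3, -1], [4, -1],
        [0, 3],   # meaning 5 -> leaves 0, 3
        [5, 1],   # meaning 6 -> meaning 5, then leaf 1
        [6, 2],   # meaning 7 -> meaning 6, then leaf 2
    ])

    def expand(m):
        """Recursively flatten a meaning into its leaf token ids."""
        if m < NORMAL_VOCAB:
            return [m]
        return [t for sub in toy_map[m] if sub >= 0 for t in expand(sub)]

    print(expand(7))  # [0, 3, 1, 2]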
@@ -246,7 +249,7 @@ class MeaningMap:

         root = NodeTree(str(meaning))
         seqlist = []
-        get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
+        get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
         root.seq_node = seqlist
         return root

@@ -256,7 +259,7 @@ class MeaningMap:
         def level_change(ms_map, meaning, current_to_common, common_to_current):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
-                if m >= self.normal_vocab_size:
+                if m >= self.normal_vocab:
                     common_to_current[-1] = common_to_current[-1] + 1
                     level_change(ms_map, m, current_to_common, common_to_current)
                 else:
@@ -319,6 +322,7 @@ class MeaningDataset(Dataset):
         start,
         end,
         vocab_size,
+        reserve_vocab=0,
         size=None,
         max_subitem=10,
         min_subitem=1,
@@ -332,6 +336,7 @@ class MeaningDataset(Dataset):
         self.start = start
         self.end = end
         self.vocab_size = vocab_size
+        self.reserve_vocab = reserve_vocab
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
         self.stride = stride
@@ -407,7 +412,14 @@ class MeaningDataset(Dataset):

     def get_meaning_map(self):
         return MeaningMap(
-            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
+            self.end,
+            self.vocab_size,
+            self.reserve_vocab,
+            self.max_subitem,
+            self.min_subitem,
+            self.stride,
+            self.with_tree,
+            self.use_cache,
         )

     def set_mask(self, level=None, idx=None):
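For reference, the reformatted call lines up positionally with the updated MeaningMap.__init__ from the earlier hunk, with self.end passed as MeaningMap's size. A keyword-argument equivalent (a sketch, not part of the commit; the with_tree and use_cache parameter names past stride are assumptions inferred from the old one-line call):

    MeaningMap(
        size=self.end,
        vocab_size=self.vocab_size,
        reserve_vocab=self.reserve_vocab,
        max_subitem=self.max_subitem,
        min_subitem=self.min_subitem,
        stride=self.stride,
        with_tree=self.with_tree,
        use_cache=self.use_cache,
    )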
@@ -44,7 +44,7 @@ if __name__ == "__main__":

     qwen.llm.hook_attention = DumpQK

-    val = ds.InitValDataset(conf).dataset
+    _, val = ds.InitDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
