Refine vocab_size of meaning dataset.

This commit is contained in:
Colin 2025-08-10 21:01:26 +08:00
parent a70e19df5d
commit 133187d7bd
1 changed file with 36 additions and 16 deletions

View File

@ -15,6 +15,7 @@ else:
class MeaningMap:
def __init__(
self,
size=1048576,
@ -30,7 +31,16 @@ class MeaningMap:
assert min_subitem <= max_subitem, "Invalid input"
np.random.seed(seed)
self.size = size
self.vocab_size = vocab_size
self.special_vocab_size = 0
if stride > 1:
self.special_vocab_size = self.special_vocab_size + 1
vocab_of_stride = vocab_size - self.special_vocab_size
if with_tree:
self.special_vocab_size = self.special_vocab_size + 1
vocab_of_tree = vocab_size - self.special_vocab_size
self.normal_vocab_size = vocab_size - self.special_vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
self.with_tree = with_tree
@ -89,8 +99,8 @@ class MeaningMap:
map[mask_zero] = -1
map[:vocab_size, 0] = np.arange(0, vocab_size)
map[:vocab_size, 1:] = -1
map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
map[: self.normal_vocab_size, 1:] = -1
ms_level = [] # meaning level, vocab's level is 0
ms_rank_idx = [] # meaning index of all level
@ -106,13 +116,14 @@ class MeaningMap:
ms_rank_all = np.zeros((datastep), dtype=np.uint32) # meaning all of all level
index = 0
for i in range(self.vocab_size):
for i in range(self.normal_vocab_size):
ms_data[index] = i
ms_level[index] = 0
ms_rank_idx[index] = 0xFFFFFFF
ms_rank_all[index] = 0xFFFFFFF
for ind in range(index + 1, index + stride):
ms_data[ind] = -1
ms_data[ind] = vocab_of_stride
ms_level[ind] = 511
ms_rank_idx[ind] = 0xFFFFFFF
ms_rank_all[ind] = 0xFFFFFFF
@ -123,7 +134,7 @@ class MeaningMap:
ms_weight[i] = 1
index = index + stride
for i in range(self.vocab_size, size):
for i in range(self.normal_vocab_size, size):
m = map[i] # 当前meaning的拆分的分支
m = m[m >= 0] # donot cut off the map such as [0]
m_len = len(m) # 当前meaning的拆分的分支个数
@ -132,7 +143,7 @@ class MeaningMap:
m_list = m.tolist()
assert m_list, "map list can not be empty list"
# 获取每个子meaning的start和end并且生成序列组合成当前meaning完整的叶index<vocab_size)
# 获取每个子meaning的start和end并且生成序列组合成当前meaning完整的叶index<self.normal_vocab_size)
idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
idxidx = np.concatenate(
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
@ -151,7 +162,7 @@ class MeaningMap:
# 拼接当前meaning的所有token到data数据结构里面
new_data = ms_data[idx]
if self.with_tree:
new_data = np.concatenate([new_data, np.array([i])])
new_data = np.concatenate([new_data, np.array([vocab_of_tree])])
ms_data[index:end] = new_data
# 处理level
new_level = ms_level[idx] + 1
@ -224,19 +235,19 @@ class MeaningMap:
)
def get_nodetree(self, meaning): # return meaning all sub items
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
def get_tree_node(ms_map, meaning, nvs, parent, seqlist):
ms = ms_map[meaning]
for m in ms[ms >= 0].tolist():
if m >= self.vocab_size:
if m >= nvs:
pn = NodeTree(str(m), parent)
get_tree_node(ms_map, m, vocab_size, pn, seqlist)
get_tree_node(ms_map, m, nvs, pn, seqlist)
else:
pn = NodeTree("<" + str(m) + ">", parent)
seqlist.append(pn)
root = NodeTree(str(meaning))
seqlist = []
get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
root.seq_node = seqlist
return root
@ -246,7 +257,7 @@ class MeaningMap:
def level_change(ms_map, meaning, current_to_common, common_to_current):
ms = ms_map[meaning]
for m in ms[ms >= 0].tolist():
if m >= self.vocab_size:
if m >= self.normal_vocab_size:
common_to_current[-1] = common_to_current[-1] + 1
level_change(ms_map, m, current_to_common, common_to_current)
else:
@ -526,8 +537,8 @@ if __name__ == "__main__":
tracemalloc.start()
md = MeaningDataset(
10000,
100000,
300000,
min_subitem=2,
max_subitem=6,
vocab_size=32,
@ -541,7 +552,16 @@ if __name__ == "__main__":
print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB")
tracemalloc.stop()
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
md = MeaningDataset(
10000,
100000,
vocab_size=32,
stride=2,
min_subitem=2,
max_subitem=6,
with_tree=False,
use_cache=False,
)
item = md.__getitem__(920)
mm = md.get_meaning_map()
mm.get_nodetree(int(item["meaning"][0])).print()
@ -550,7 +570,7 @@ if __name__ == "__main__":
print(item_seq)
print(item_mask)
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
md = MeaningDataset(10000, 100000, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
train, val = md.split(0.95)
item = md.__getitem__(920)