Refine vocab_size of meaning dataset.
This commit is contained in:
parent
a70e19df5d
commit
133187d7bd
|
@ -15,6 +15,7 @@ else:
|
|||
|
||||
|
||||
class MeaningMap:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size=1048576,
|
||||
|
@ -30,7 +31,16 @@ class MeaningMap:
|
|||
assert min_subitem <= max_subitem, "Invalid input"
|
||||
np.random.seed(seed)
|
||||
self.size = size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.special_vocab_size = 0
|
||||
if stride > 1:
|
||||
self.special_vocab_size = self.special_vocab_size + 1
|
||||
vocab_of_stride = vocab_size - self.special_vocab_size
|
||||
if with_tree:
|
||||
self.special_vocab_size = self.special_vocab_size + 1
|
||||
vocab_of_tree = vocab_size - self.special_vocab_size
|
||||
self.normal_vocab_size = vocab_size - self.special_vocab_size
|
||||
|
||||
self.max_subitem = max_subitem
|
||||
self.min_subitem = min_subitem
|
||||
self.with_tree = with_tree
|
||||
|
@ -89,8 +99,8 @@ class MeaningMap:
|
|||
|
||||
map[mask_zero] = -1
|
||||
|
||||
map[:vocab_size, 0] = np.arange(0, vocab_size)
|
||||
map[:vocab_size, 1:] = -1
|
||||
map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
|
||||
map[: self.normal_vocab_size, 1:] = -1
|
||||
|
||||
ms_level = [] # meaning level, vocab's level is 0
|
||||
ms_rank_idx = [] # meaning index of all level
|
||||
|
@ -106,13 +116,14 @@ class MeaningMap:
|
|||
ms_rank_all = np.zeros((datastep), dtype=np.uint32) # meaning all of all level
|
||||
|
||||
index = 0
|
||||
for i in range(self.vocab_size):
|
||||
|
||||
for i in range(self.normal_vocab_size):
|
||||
ms_data[index] = i
|
||||
ms_level[index] = 0
|
||||
ms_rank_idx[index] = 0xFFFFFFF
|
||||
ms_rank_all[index] = 0xFFFFFFF
|
||||
for ind in range(index + 1, index + stride):
|
||||
ms_data[ind] = -1
|
||||
ms_data[ind] = vocab_of_stride
|
||||
ms_level[ind] = 511
|
||||
ms_rank_idx[ind] = 0xFFFFFFF
|
||||
ms_rank_all[ind] = 0xFFFFFFF
|
||||
|
@ -123,7 +134,7 @@ class MeaningMap:
|
|||
ms_weight[i] = 1
|
||||
index = index + stride
|
||||
|
||||
for i in range(self.vocab_size, size):
|
||||
for i in range(self.normal_vocab_size, size):
|
||||
m = map[i] # 当前meaning的拆分的分支
|
||||
m = m[m >= 0] # donot cut off the map such as [0]
|
||||
m_len = len(m) # 当前meaning的拆分的分支个数
|
||||
|
@ -132,7 +143,7 @@ class MeaningMap:
|
|||
m_list = m.tolist()
|
||||
assert m_list, "map list can not be empty list"
|
||||
|
||||
# 获取每个子meaning的start和end,并且生成序列组合成当前meaning完整的叶index(<vocab_size)
|
||||
# 获取每个子meaning的start和end,并且生成序列组合成当前meaning完整的叶index(<self.normal_vocab_size)
|
||||
idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
|
||||
idxidx = np.concatenate(
|
||||
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
|
||||
|
@ -151,7 +162,7 @@ class MeaningMap:
|
|||
# 拼接当前meaning的所有token到data数据结构里面
|
||||
new_data = ms_data[idx]
|
||||
if self.with_tree:
|
||||
new_data = np.concatenate([new_data, np.array([i])])
|
||||
new_data = np.concatenate([new_data, np.array([vocab_of_tree])])
|
||||
ms_data[index:end] = new_data
|
||||
# 处理level
|
||||
new_level = ms_level[idx] + 1
|
||||
|
@ -224,19 +235,19 @@ class MeaningMap:
|
|||
)
|
||||
|
||||
def get_nodetree(self, meaning): # return meaning all sub items
|
||||
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
|
||||
def get_tree_node(ms_map, meaning, nvs, parent, seqlist):
|
||||
ms = ms_map[meaning]
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= self.vocab_size:
|
||||
if m >= nvs:
|
||||
pn = NodeTree(str(m), parent)
|
||||
get_tree_node(ms_map, m, vocab_size, pn, seqlist)
|
||||
get_tree_node(ms_map, m, nvs, pn, seqlist)
|
||||
else:
|
||||
pn = NodeTree("<" + str(m) + ">", parent)
|
||||
seqlist.append(pn)
|
||||
|
||||
root = NodeTree(str(meaning))
|
||||
seqlist = []
|
||||
get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
|
||||
get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
|
||||
root.seq_node = seqlist
|
||||
return root
|
||||
|
||||
|
@ -246,7 +257,7 @@ class MeaningMap:
|
|||
def level_change(ms_map, meaning, current_to_common, common_to_current):
|
||||
ms = ms_map[meaning]
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= self.vocab_size:
|
||||
if m >= self.normal_vocab_size:
|
||||
common_to_current[-1] = common_to_current[-1] + 1
|
||||
level_change(ms_map, m, current_to_common, common_to_current)
|
||||
else:
|
||||
|
@ -526,8 +537,8 @@ if __name__ == "__main__":
|
|||
tracemalloc.start()
|
||||
|
||||
md = MeaningDataset(
|
||||
10000,
|
||||
100000,
|
||||
300000,
|
||||
min_subitem=2,
|
||||
max_subitem=6,
|
||||
vocab_size=32,
|
||||
|
@ -541,7 +552,16 @@ if __name__ == "__main__":
|
|||
print(f"峰值内存使用: {peak / 1024 / 1024 / 1024:.4f} GB")
|
||||
tracemalloc.stop()
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
|
||||
md = MeaningDataset(
|
||||
10000,
|
||||
100000,
|
||||
vocab_size=32,
|
||||
stride=2,
|
||||
min_subitem=2,
|
||||
max_subitem=6,
|
||||
with_tree=False,
|
||||
use_cache=False,
|
||||
)
|
||||
item = md.__getitem__(920)
|
||||
mm = md.get_meaning_map()
|
||||
mm.get_nodetree(int(item["meaning"][0])).print()
|
||||
|
@ -550,7 +570,7 @@ if __name__ == "__main__":
|
|||
print(item_seq)
|
||||
print(item_mask)
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
|
||||
md = MeaningDataset(10000, 100000, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
|
||||
train, val = md.split(0.95)
|
||||
item = md.__getitem__(920)
|
||||
|
||||
|
|
Loading…
Reference in New Issue