Refine vocab_size of meaning dataset.
This commit is contained in:
parent a70e19df5d
commit 133187d7bd
@@ -15,6 +15,7 @@ else:
 
 
 class MeaningMap:
+
     def __init__(
         self,
         size=1048576,
@@ -30,7 +31,16 @@ class MeaningMap:
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
         self.size = size
         self.vocab_size = vocab_size
+        self.special_vocab_size = 0
+        if stride > 1:
+            self.special_vocab_size = self.special_vocab_size + 1
+            vocab_of_stride = vocab_size - self.special_vocab_size
+        if with_tree:
+            self.special_vocab_size = self.special_vocab_size + 1
+            vocab_of_tree = vocab_size - self.special_vocab_size
+        self.normal_vocab_size = vocab_size - self.special_vocab_size
+
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
         self.with_tree = with_tree
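Note on the new fields (an illustrative sketch, not part of the diff; the values below are assumed): when both stride padding and the tree marker are enabled, two ids are carved off the top of the vocabulary, and only the remaining normal_vocab_size ids serve as ordinary tokens.

# Reservation arithmetic from __init__ above, traced with assumed example values.
vocab_size, stride, with_tree = 32, 2, True

special_vocab_size = 0
if stride > 1:
    special_vocab_size += 1
    vocab_of_stride = vocab_size - special_vocab_size  # 31: stride-padding id
if with_tree:
    special_vocab_size += 1
    vocab_of_tree = vocab_size - special_vocab_size    # 30: tree-marker id
normal_vocab_size = vocab_size - special_vocab_size    # 30 ordinary token ids remain

assert (vocab_of_stride, vocab_of_tree, normal_vocab_size) == (31, 30, 30)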
@@ -89,8 +99,8 @@ class MeaningMap:
 
         map[mask_zero] = -1
 
-        map[:vocab_size, 0] = np.arange(0, vocab_size)
-        map[:vocab_size, 1:] = -1
+        map[: self.normal_vocab_size, 0] = np.arange(0, self.normal_vocab_size)
+        map[: self.normal_vocab_size, 1:] = -1
 
         ms_level = []  # meaning level, vocab's level is 0
         ms_rank_idx = []  # meaning index across all levels
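Only the first normal_vocab_size rows of map are seeded as plain tokens now, so the reserved ids never appear as expandable entries. A minimal sketch with assumed toy sizes (the variable naming mirrors the diff, including its shadowing of the map builtin):

import numpy as np

normal_vocab_size, size, max_subitem = 30, 40, 4  # assumed toy sizes
map = np.random.randint(0, size, (size, max_subitem)).astype(np.int32)
map[: normal_vocab_size, 0] = np.arange(0, normal_vocab_size)  # a token's row maps to itself
map[: normal_vocab_size, 1:] = -1                              # plain tokens have no branches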
@@ -106,13 +116,14 @@ class MeaningMap:
         ms_rank_all = np.zeros((datastep), dtype=np.uint32)  # meaning all across all levels
 
         index = 0
-        for i in range(self.vocab_size):
+        for i in range(self.normal_vocab_size):
             ms_data[index] = i
             ms_level[index] = 0
             ms_rank_idx[index] = 0xFFFFFFF
             ms_rank_all[index] = 0xFFFFFFF
             for ind in range(index + 1, index + stride):
-                ms_data[ind] = -1
+                ms_data[ind] = vocab_of_stride
                 ms_level[ind] = 511
                 ms_rank_idx[ind] = 0xFFFFFFF
                 ms_rank_all[ind] = 0xFFFFFFF
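The padding slots between leaf tokens (when stride > 1) used to hold -1; they now hold the reserved vocab_of_stride id, which keeps every emitted value inside the vocabulary range. A minimal standalone sketch, with the same assumed values as above:

import numpy as np

stride, normal_vocab_size, vocab_of_stride = 2, 30, 31  # assumed values
ms_data = np.zeros(normal_vocab_size * stride, dtype=np.int32)
index = 0
for i in range(normal_vocab_size):
    ms_data[index] = i                  # the real token id
    for ind in range(index + 1, index + stride):
        ms_data[ind] = vocab_of_stride  # reserved padding id (previously -1)
    index = index + stride
print(ms_data[:6])  # [ 0 31  1 31  2 31]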
@@ -123,7 +134,7 @@ class MeaningMap:
             ms_weight[i] = 1
             index = index + stride
 
-        for i in range(self.vocab_size, size):
+        for i in range(self.normal_vocab_size, size):
             m = map[i]  # branches that the current meaning splits into
             m = m[m >= 0]  # do not cut off the map, e.g. [0]
             m_len = len(m)  # number of branches the current meaning splits into
@@ -132,7 +143,7 @@ class MeaningMap:
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"
 
-            # Get each sub-meaning's start and end, and build the sequence of the current meaning's complete leaf indices (< vocab_size)
+            # Get each sub-meaning's start and end, and build the sequence of the current meaning's complete leaf indices (< self.normal_vocab_size)
             idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
             idxidx = np.concatenate(
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
@@ -151,7 +162,7 @@ class MeaningMap:
             # Concatenate all tokens of the current meaning into the data structure
             new_data = ms_data[idx]
             if self.with_tree:
-                new_data = np.concatenate([new_data, np.array([i])])
+                new_data = np.concatenate([new_data, np.array([vocab_of_tree])])
             ms_data[index:end] = new_data
             # Handle the level
             new_level = ms_level[idx] + 1
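With with_tree enabled, each composed meaning previously appended its own meaning id i as a closing marker, which could exceed vocab_size; it now appends the fixed vocab_of_tree id. Illustrative sketch (assumed values):

import numpy as np

vocab_of_tree = 30                  # assumed reserved id, as above
new_data = np.array([3, 7, 7, 12])  # tokens gathered from the sub-meanings
new_data = np.concatenate([new_data, np.array([vocab_of_tree])])
print(new_data)  # [ 3  7  7 12 30]: the closing marker is a fixed in-vocabulary token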
@@ -224,19 +235,19 @@ class MeaningMap:
         )
 
     def get_nodetree(self, meaning):  # return all sub items of a meaning
-        def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
+        def get_tree_node(ms_map, meaning, nvs, parent, seqlist):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
-                if m >= self.vocab_size:
+                if m >= nvs:
                     pn = NodeTree(str(m), parent)
-                    get_tree_node(ms_map, m, vocab_size, pn, seqlist)
+                    get_tree_node(ms_map, m, nvs, pn, seqlist)
                 else:
                     pn = NodeTree("<" + str(m) + ">", parent)
                     seqlist.append(pn)
 
         root = NodeTree(str(meaning))
         seqlist = []
-        get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
+        get_tree_node(self.ms_map, meaning, self.normal_vocab_size, root, seqlist)
         root.seq_node = seqlist
         return root
 
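The renamed nvs parameter is the leaf/composite boundary: map entries below normal_vocab_size print as <token> leaves, entries at or above it recurse as sub-meanings. A toy walk over a hypothetical map (a dict standing in for the real ndarray):

# Hypothetical toy structure; ids < nvs are leaf tokens, ids >= nvs expand further.
nvs = 30
ms_map = {31: [30, 5], 30: [1, 2]}

def walk(meaning, depth=0):
    for m in ms_map.get(meaning, []):
        if m >= nvs:
            print("  " * depth + str(m))              # composite meaning: recurse
            walk(m, depth + 1)
        else:
            print("  " * depth + "<" + str(m) + ">")  # leaf token

walk(31)  # prints 30, <1>, <2>, <5> with indentation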
@@ -246,7 +257,7 @@ class MeaningMap:
         def level_change(ms_map, meaning, current_to_common, common_to_current):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
-                if m >= self.vocab_size:
+                if m >= self.normal_vocab_size:
                     common_to_current[-1] = common_to_current[-1] + 1
                     level_change(ms_map, m, current_to_common, common_to_current)
                 else:
@@ -526,8 +537,8 @@ if __name__ == "__main__":
     tracemalloc.start()
 
     md = MeaningDataset(
+        10000,
         100000,
-        300000,
         min_subitem=2,
         max_subitem=6,
         vocab_size=32,
@@ -541,7 +552,16 @@ if __name__ == "__main__":
     print(f"Peak memory usage: {peak / 1024 / 1024 / 1024:.4f} GB")
     tracemalloc.stop()
 
-    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
+    md = MeaningDataset(
+        10000,
+        100000,
+        vocab_size=32,
+        stride=2,
+        min_subitem=2,
+        max_subitem=6,
+        with_tree=False,
+        use_cache=False,
+    )
     item = md.__getitem__(920)
     mm = md.get_meaning_map()
     mm.get_nodetree(int(item["meaning"][0])).print()
@@ -550,7 +570,7 @@ if __name__ == "__main__":
     print(item_seq)
     print(item_mask)
 
-    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
+    md = MeaningDataset(10000, 100000, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
     train, val = md.split(0.95)
     item = md.__getitem__(920)
 