Add stride support in meaning_dataset.

This commit is contained in:
Colin 2025-08-07 17:36:43 +08:00
parent d7191003e0
commit 9c75b8920d
2 changed files with 47 additions and 22 deletions

View File

@@ -11,7 +11,15 @@ from node_tree import NodeTree
class MeaningMap: class MeaningMap:
def __init__( def __init__(
self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, with_parent=False, use_cache=True, seed=42 self,
size=1048576,
vocab_size=4096,
max_subitem=10,
min_subitem=1,
stride=1,
with_tree=False,
use_cache=True,
seed=42,
): ):
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input" assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
assert min_subitem <= max_subitem, "Invalid input" assert min_subitem <= max_subitem, "Invalid input"
@@ -20,7 +28,8 @@ class MeaningMap:
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_subitem = max_subitem self.max_subitem = max_subitem
self.min_subitem = min_subitem self.min_subitem = min_subitem
self.with_parent = with_parent self.with_tree = with_tree
self.stride = stride
datastep = 0x8000000 datastep = 0x8000000
path = "./data/" path = "./data/"
@@ -93,22 +102,27 @@ class MeaningMap:
index = 0 index = 0
for i in range(self.vocab_size): for i in range(self.vocab_size):
ms_data[i] = i ms_data[index] = i
ms_level[index] = 0
ms_rank_idx[index] = 0xFFFFFFF
ms_rank_all[index] = 0xFFFFFFF
for ind in range(index + 1, index + stride):
ms_data[ind] = -1
ms_level[ind] = 511
ms_rank_idx[ind] = 0xFFFFFFF
ms_rank_all[ind] = 0xFFFFFFF
ms_start[i] = index ms_start[i] = index
ms_end[i] = index + 1 ms_end[i] = index + stride
ms_len[i] = 1 ms_len[i] = stride
ms_level[i] = 0
ms_rank_idx[i] = 0xFFFFFFF
ms_rank_all[i] = 0xFFFFFFF
ms_height[i] = 0 ms_height[i] = 0
ms_weight[i] = 1 ms_weight[i] = 1
index = index + 1 index = index + stride
for i in range(self.vocab_size, size): for i in range(self.vocab_size, size):
m = map[i] # 当前meaning的拆分的分支 m = map[i] # 当前meaning的拆分的分支
m = m[m >= 0] # donot cut off the map such as [0] m = m[m >= 0] # donot cut off the map such as [0]
m_len = len(m) # 当前meaning的拆分的分支个数 m_len = len(m) # 当前meaning的拆分的分支个数
if self.with_parent: if self.with_tree:
m_len = m_len + 1 m_len = m_len + 1
m_list = m.tolist() m_list = m.tolist()
assert m_list, "map list can not be empty list" assert m_list, "map list can not be empty list"
@@ -119,7 +133,7 @@ class MeaningMap:
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])] [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
) )
len_ma = len(idx) len_ma = len(idx)
if self.with_parent: if self.with_tree:
len_ma = len_ma + 1 len_ma = len_ma + 1
end = index + len_ma end = index + len_ma
@@ -131,22 +145,22 @@ class MeaningMap:
# 拼接当前meaning的所有token到data数据结构里面 # 拼接当前meaning的所有token到data数据结构里面
new_data = ms_data[idx] new_data = ms_data[idx]
if self.with_parent: if self.with_tree:
new_data = np.concatenate([new_data, np.array([i])]) new_data = np.concatenate([new_data, np.array([i])])
ms_data[index:end] = new_data ms_data[index:end] = new_data
# 处理level # 处理level
new_level = ms_level[idx] + 1 new_level = ms_level[idx] + 1
if self.with_parent: if self.with_tree:
new_level = np.concatenate([new_level, np.array([256])]) new_level = np.concatenate([new_level, np.array([255 + 1])])
ms_level[index:end] = new_level ms_level[index:end] = new_level
# 处理rank_idx # 处理rank_idx
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
if self.with_parent: if self.with_tree:
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])]) new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
ms_rank_idx[index:end] = new_idx ms_rank_idx[index:end] = new_idx
# 处理rank_all # 处理rank_all
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
if self.with_parent: if self.with_tree:
new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])]) new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
ms_rank_all[index:end] = new_all ms_rank_all[index:end] = new_all
@@ -295,7 +309,8 @@ class MeaningDataset(Dataset):
min_subitem=1, min_subitem=1,
min_seq_len=2, min_seq_len=2,
seed=42, seed=42,
with_parent=False, stride=1,
with_tree=False,
use_cache=True, use_cache=True,
): ):
np.random.seed(seed) np.random.seed(seed)
@@ -304,7 +319,8 @@ class MeaningDataset(Dataset):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_subitem = max_subitem self.max_subitem = max_subitem
self.min_subitem = min_subitem self.min_subitem = min_subitem
self.with_parent = with_parent self.stride = stride
self.with_tree = with_tree
self.use_cache = use_cache self.use_cache = use_cache
self.min_seq_len = min_seq_len self.min_seq_len = min_seq_len
print("Build MeaningDataset from MeaningMap.") print("Build MeaningDataset from MeaningMap.")
@@ -376,7 +392,7 @@ class MeaningDataset(Dataset):
def get_meaning_map(self): def get_meaning_map(self):
return MeaningMap( return MeaningMap(
self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.with_parent, self.use_cache self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
) )
def set_mask(self, level=None, idx=None): def set_mask(self, level=None, idx=None):
@@ -502,7 +518,16 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__": if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, with_parent=True, use_cache=False) md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
item = md.__getitem__(920)
mm = md.get_meaning_map()
mm.get_nodetree(int(item["meaning"][0])).print()
item_seq = item["input_ids"].numpy().tolist()
item_mask = item["val_mask"].numpy().tolist()
print(item_seq)
print(item_mask)
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
train, val = md.split(0.95) train, val = md.split(0.95)
item = md.__getitem__(920) item = md.__getitem__(920)

View File

@@ -18,8 +18,8 @@ meaning数据集是一个模仿自然语言以及抽象表达的数据集。
12. meaning_height 当前meaning的总高度 12. meaning_height 当前meaning的总高度
13. meaning_weight 当前meaning的总宽度 13. meaning_weight 当前meaning的总宽度
14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练 14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
15. with_parent 表示每个meaning的拆分都以当前meaning编号结束sequence会插入很多枝节点而不仅仅是叶节点 15. with_tree 表示每个meaning的拆分都以当前meaning编号结束sequence会插入很多枝节点而不仅仅是叶节点
16. ms_level 大于255表示这个token是parentlevel 0用256表示生成val mask的时候是False 16. ms_level 大于255表示这个token是treelevel 0用255表示大于511表示这个token是stridelevel 0用511表示生成val mask的时候是False
17. ms_map 表示每个meaning拆解的sub meaning 17. ms_map 表示每个meaning拆解的sub meaning
18. index must < 15level must < 8 18. index must < 15level must < 8