Add parent support in meaning_dataset.

This commit is contained in:
Colin 2025-08-04 16:10:46 +08:00
parent 8882073978
commit d7191003e0
2 changed files with 61 additions and 89 deletions

View File

@ -6,14 +6,13 @@ from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
from dataset.node_tree import NodeTree
# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
from node_tree import NodeTree
class MeaningMap:
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
def __init__(
self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, with_parent=False, use_cache=True, seed=42
):
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
assert min_subitem <= max_subitem, "Invalid input"
np.random.seed(seed)
@ -21,6 +20,7 @@ class MeaningMap:
self.vocab_size = vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
self.with_parent = with_parent
datastep = 0x8000000
path = "./data/"
@ -108,6 +108,8 @@ class MeaningMap:
m = map[i] # 当前meaning的拆分的分支
m = m[m >= 0] # donot cut off the map such as [0]
m_len = len(m) # 当前meaning的拆分的分支个数
if self.with_parent:
m_len = m_len + 1
m_list = m.tolist()
assert m_list, "map list can not be empty list"
@ -117,6 +119,8 @@ class MeaningMap:
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
)
len_ma = len(idx)
if self.with_parent:
len_ma = len_ma + 1
end = index + len_ma
if ms_data.size < end: # 超过存储数据结构的大小扩展一个datastep容量
@ -125,10 +129,26 @@ class MeaningMap:
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
ms_data[index:end] = ms_data[idx] # 拼接当前meaning的所有token到data数据结构里面
ms_level[index:end] = ms_level[idx] + 1 # 处理level
ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) # 处理rank_idx
ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) # 处理rank_all
# 拼接当前meaning的所有token到data数据结构里面
new_data = ms_data[idx]
if self.with_parent:
new_data = np.concatenate([new_data, np.array([i])])
ms_data[index:end] = new_data
# 处理level
new_level = ms_level[idx] + 1
if self.with_parent:
new_level = np.concatenate([new_level, np.array([256])])
ms_level[index:end] = new_level
# 处理rank_idx
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
if self.with_parent:
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
ms_rank_idx[index:end] = new_idx
# 处理rank_all
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
if self.with_parent:
new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
ms_rank_all[index:end] = new_all
ms_start[i] = index
ms_end[i] = end
@ -275,6 +295,7 @@ class MeaningDataset(Dataset):
min_subitem=1,
min_seq_len=2,
seed=42,
with_parent=False,
use_cache=True,
):
np.random.seed(seed)
@ -283,6 +304,7 @@ class MeaningDataset(Dataset):
self.vocab_size = vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
self.with_parent = with_parent
self.use_cache = use_cache
self.min_seq_len = min_seq_len
print("Build MeaningDataset from MeaningMap.")
@ -310,9 +332,11 @@ class MeaningDataset(Dataset):
self.seq_meaning.append(m)
seq_len.append(len(d))
origin_l = l.copy()
origin_l[l >= 255] = l[l >= 255] - 255
dm = np.ones(i.shape, dtype=np.uint32)
dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
shift = (8 - l) * 4
dm = ((dm * 0xFFFFFFFF) << (origin_l * 4)).astype(np.uint32)
shift = (8 - origin_l) * 4
rank_idx = (i & 0xF) << 28
rank_idx = rank_idx + ((i & 0xF0) << 20)
rank_idx = rank_idx + ((i & 0xF00) << 12)
@ -351,14 +375,16 @@ class MeaningDataset(Dataset):
return len(self.seq)
def get_meaning_map(self):
return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
return MeaningMap(
self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.with_parent, self.use_cache
)
def set_mask(self, level=None, idx=None):
if self.val_mask_level is not None and self.val_mask_idx is not None:
assert len(self.val_mask_level) > 0, "len must > 0"
assert len(self.val_mask_level) == len(self.val_mask_idx), "mask level and mask index must be same length"
assert isinstance(self.val_mask_level, list), "mask level must be list"
assert isinstance(self.val_mask_idx, list), "mask index must be list"
if level is not None and idx is not None:
assert len(level) > 0, "len must > 0"
assert len(level) == len(idx), "mask level and mask index must be same length"
assert isinstance(level, list), "mask level must be list"
assert isinstance(idx, list), "mask index must be list"
self.val_mask_level = level
self.val_mask_idx = idx
@ -403,7 +429,7 @@ class MeaningDataset(Dataset):
# assert level < 8, "level must < 8"
rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
return rank_idx == (rank_all + index if index < 0 else index)
return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)
def get_seq_mask_tensor(self, idx_list):
if self.val_mask_level is not None and self.val_mask_idx is not None:
@ -418,7 +444,7 @@ class MeaningDataset(Dataset):
)
return mask
else:
return None
return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
class BatchGroupMeaningDataloader(Dataset):
@ -476,11 +502,19 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
md.set_mask([1], [-1])
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, with_parent=True, use_cache=False)
train, val = md.split(0.95)
fdaf = md.__getitem__(920)
print(md.print_tree(920))
item = md.__getitem__(920)
mm = md.get_meaning_map()
mm.get_nodetree(int(item["meaning"][0])).print()
item_seq = item["input_ids"].numpy().tolist()
item_mask = item["val_mask"].numpy().tolist()
print(item_seq)
print(item_mask)
md.set_mask([1], [-1])
print(md.rank_idx[920])
print(md.rank_all[920])
mask = md.get_seq_mask(920, 0, -1)
@ -491,72 +525,6 @@ if __name__ == "__main__":
print(mask)
mask = md.get_seq_mask(920, 1, 1)
print(mask)
assert all(
np.equal(
mask[0:57],
np.array(
[
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
True,
False,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
]
),
)
), "False"
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
md.set_mask([0, 1], [0, -1])

View File

@ -18,6 +18,10 @@ meaning数据集是一个模仿自然语言以及抽象表达的数据集。
12. meaning_height 当前meaning的总高度
13. meaning_weight 当前meaning的总宽度
14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
15. with_parent 表示每个meaning的拆分都以当前meaning编号结束sequence会插入很多枝节点而不仅仅是叶节点
16. ms_level 大于255表示这个token是parentlevel 0用256表示生成val mask的时候是False
17. ms_map 表示每个meaning拆解的sub meaning
18. index must < 15level must < 8
## code