Fix rank index and all calculater when tree mode enable.

This commit is contained in:
Colin 2025-08-14 13:25:47 +08:00
parent 47ba18a4f9
commit a983a2b6e6
2 changed files with 10 additions and 13 deletions

View File

@ -11,8 +11,8 @@ meaning数据集是一个模仿自然语言以及抽象表达的数据集。
5. meaning通过一层层的向低编号的meaning进行组合替换最终形成一个最底层是token的树形数据
6. level表示当前token相对于root meaning的距离
7. rank
8. rank_idx表示当前token在不同层的排序编号每4位表示在一层里面的编号低4位表示最低层级的rank_idx高位无用的位用1填充
9. rank_all表示当前token所在的不同层的总的分支个数每4位表示在一层里面的个数低4位表示最低层级的rank_all高位无用的位用1填充
8. rank_idx表示当前token在不同层的排序编号每4位表示在一层里面的编号低4位表示最低层级的rank_idx高位无用的位用1填充,tree token总是排在最后面index递增
9. rank_all表示当前token所在的不同层的总的分支个数每4位表示在一层里面的个数低4位表示最低层级的rank_all高位无用的位用1填充tree token不计入rank all的总数
10. tree用于存储每个meaning的拆解的数据使用字典表达一个树形结构
11. get_rank_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层index=-1:最后一个index=0:第一个
12. meaning_height 当前meaning的总高度

View File

@ -139,8 +139,6 @@ class MeaningMap:
m = map[i] # 当前meaning的拆分的分支
m = m[m >= 0] # donot cut off the map such as [0]
m_len = len(m) # 当前meaning的拆分的分支个数
if self.with_tree:
m_len = m_len + 1
m_list = m.tolist()
assert m_list, "map list can not be empty list"
@ -173,7 +171,7 @@ class MeaningMap:
# 处理rank_idx
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
if self.with_tree:
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len])])
ms_rank_idx[index:end] = new_idx
# 处理rank_all
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
@ -430,10 +428,10 @@ class MeaningDataset(Dataset):
data = torch.tensor(np.stack(data, axis=0)).long()
output["input_ids"] = data
output["token_type_ids"] = torch.zeros(data.shape)
val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
val_mask, token_mask = self.get_seq_mask_tensor(idx_list)
output["val_mask"] = val_mask
labels = data.clone()
# labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label
# labels[~token_mask] = self.vocab_size # set to vocab_size will be masked in label
output["labels"] = labels
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
return output
@ -468,14 +466,13 @@ class MeaningDataset(Dataset):
return rank_idx == (rank_all + index if index < 0 else index)
def get_seq_mask_tensor(self, idx_list):
stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
val_mask = stride_mask.clone()
token_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
val_mask = token_mask.clone()
if self.val_mask_level and self.val_mask_idx:
for i, l in enumerate(self.val_mask_level):
val_mask = val_mask & torch.tensor(
np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0)
)
return (val_mask, stride_mask)
mask = [self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list]
val_mask = val_mask & torch.tensor(np.stack(mask, axis=0))
return (val_mask, token_mask)
class BatchGroupMeaningDataloader(Dataset):