Update meaning of stride mask label.

Colin 2025-08-11 16:07:32 +08:00
parent 3f0eedfef8
commit 2a09b9d9b1
2 changed files with 19 additions and 21 deletions


@@ -14,7 +14,7 @@ The meaning dataset mimics natural language and abstract expression.
 8. rank_idx: the ordinal of the current token at each level; every 4 bits hold the ordinal within one level, the lowest 4 bits hold the lowest level's rank_idx, and unused high bits are filled with 1s (see the sketch after this diff)
 9. rank_all: the total number of branches at each level containing the current token; every 4 bits hold the count for one level, the lowest 4 bits hold the lowest level's rank_all, and unused high bits are filled with 1s
 10. tree: stores the decomposition of each meaning, using a dict to express a tree structure
-11. get_seq_mask: returns, for each token of a sequence, whether it is at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
+11. get_rank_mask: returns, for each token of a sequence, whether it is at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level, val_mask_idx: the mask over tokens used for training; val_mask_level=[0, 1, 2] with val_mask_idx=[0, 0, -1] means only a token that is the 0th at level 0, the 0th at level 1, and the last at level 2 participates in training
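To make items 8, 9, and 11 concrete, here is a minimal self-contained sketch of the 4-bit packing and the index test; the helper names are illustrative, not part of the repository:

# Hypothetical helpers illustrating the packing described above.
def unpack_rank(packed: int, level: int) -> int:
    # Each level occupies 4 bits; level 0 sits in the lowest 4 bits.
    return (packed >> (4 * level)) & 0xF

def at_index(packed_idx: int, packed_all: int, level: int, index: int) -> bool:
    # index = -1 selects the last branch (rank_all - 1); index >= 0 selects directly.
    ri = unpack_rank(packed_idx, level)
    ra = unpack_rank(packed_all, level)
    return ri == (ra + index if index < 0 else index)

# A token that is branch 2 of 3 at level 0 and branch 0 of 2 at level 1,
# with the unused high bits filled with 1s:
packed_idx = (0xFFFFFF << 8) | (0x0 << 4) | 0x2   # 0xFFFFFF02
packed_all = (0xFFFFFF << 8) | (0x2 << 4) | 0x3   # 0xFFFFFF23
assert at_index(packed_idx, packed_all, 0, -1)    # last branch of level 0
assert at_index(packed_idx, packed_all, 1, 0)     # first branch of level 1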


@@ -429,9 +429,12 @@ class MeaningDataset(Dataset):
         output = {}
         data = torch.tensor(np.stack(data, axis=0)).long()
         output["input_ids"] = data
-        output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
-        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
+        val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
+        output["val_mask"] = val_mask
+        labels = data.clone()
+        labels[~stride_mask] = self.vocab_size  # positions set to vocab_size are masked out of the labels
+        output["labels"] = labels
         output["meaning"] = [self.seq_meaning[i] for i in idx_list]
         return output
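The new labels handling writes vocab_size into every position outside the stride mask. A minimal sketch of one way that convention can be honored downstream, assuming the training loss is wired with ignore_index=vocab_size (the actual loss setup is not part of this diff):

import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(2, 5, vocab_size)              # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 5))
stride_mask = torch.tensor([[1, 1, 0, 1, 1],
                            [1, 0, 1, 1, 0]], dtype=torch.bool)
labels[~stride_mask] = vocab_size                   # same convention as the diff

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=vocab_size,                        # masked positions contribute nothing
)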
@@ -457,27 +460,22 @@ class MeaningDataset(Dataset):
         middle = int(l * ratio)
         return self.copy(0, middle), self.copy(middle, l)
 
-    def get_seq_mask(self, idx, level, index):
+    def get_rank_mask(self, idx, level, index):
         # assert index < 15, "index must < 15"
         # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
-        return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)
+        return rank_idx == (rank_all + index if index < 0 else index)
 
     def get_seq_mask_tensor(self, idx_list):
-        if self.val_mask_level is not None and self.val_mask_idx is not None:
-            mask = torch.tensor(
-                np.stack(
-                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
-                )
-            )
-            for i, l in enumerate(self.val_mask_level[1:]):
-                mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
-                )
-            return mask
-        else:
-            return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        val_mask = stride_mask.clone()
+        if self.val_mask_level and self.val_mask_idx:
+            for i, l in enumerate(self.val_mask_level):
+                val_mask = val_mask & torch.tensor(
+                    np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0)
+                )
+        return (val_mask, stride_mask)
 
 
 class BatchGroupMeaningDataloader(Dataset):
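For context, get_seq_mask_tensor now always returns a pair: stride_mask marks non-padding positions (level < 255), and val_mask is stride_mask AND'ed with each requested per-level rank mask. A standalone sketch of that contract, with toy arrays standing in for the dataset's rank_idx, rank_all, and level fields:

import numpy as np
import torch

# Toy stand-ins for self.rank_idx / self.rank_all / self.level (one sequence
# of four tokens; the last token is padding, marked by level == 255).
rank_idx = np.array([0xF0, 0xF1, 0xF2, 0xFF], dtype=np.uint32)  # level-0 ranks 0, 1, 2
rank_all = np.array([0xF3, 0xF3, 0xF3, 0xFF], dtype=np.uint32)  # 3 branches at level 0
level    = np.array([0, 0, 0, 255], dtype=np.uint8)

stride_mask = torch.tensor(level < 255)       # non-padding positions
ri = (rank_idx >> 0).astype(np.int32) & 0xF   # shift by 4 * level, here level = 0
ra = (rank_all >> 0).astype(np.int32) & 0xF
rank0_last = ri == ra - 1                     # level=0, index=-1
val_mask = stride_mask & torch.tensor(rank0_last)

print(stride_mask.tolist())  # [True, True, True, False]
print(val_mask.tolist())     # [False, False, True, False]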
@@ -586,13 +584,13 @@ if __name__ == "__main__":
     print(md.rank_idx[920])
     print(md.rank_all[920])
-    mask = md.get_seq_mask(920, 0, -1)
+    mask = md.get_rank_mask(920, 0, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 0)
+    mask = md.get_rank_mask(920, 1, 0)
     print(mask)
-    mask = md.get_seq_mask(920, 1, -1)
+    mask = md.get_rank_mask(920, 1, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 1)
+    mask = md.get_rank_mask(920, 1, 1)
     print(mask)
 
     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)