Update meaning of stride mask label.
parent 3f0eedfef8
commit 2a09b9d9b1
@@ -14,7 +14,7 @@ The meaning dataset imitates natural language as well as abstract expression.
 8. rank_idx: the position of the current token at each level; every 4 bits encode the position within one level, the lowest 4 bits encode the lowest level, and unused high bits are filled with 1s
 9. rank_all: the total number of branches at each level the current token belongs to; every 4 bits encode the count within one level, the lowest 4 bits encode the lowest level, and unused high bits are filled with 1s
 10. tree: stores the decomposition of each meaning, using a dict to express a tree structure
-11. get_seq_mask returns, for each token of a sequence, whether it sits at the given index of the given level; level=0: lowest level, index=-1: last one, index=0: first one
+11. get_rank_mask returns, for each token of a sequence, whether it sits at the given index of the given level; level=0: lowest level, index=-1: last one, index=0: first one
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level val_mask_idx: the mask selecting the tokens used for training; val_mask_level=[0, 1, 2] and val_mask_idx=[0, 0, -1] mean that only tokens that are the 0th at level 0, the 0th at level 1, and the last at level 2 take part in training
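The 4-bit-per-level packing described for rank_idx and rank_all can be decoded with plain shifts and masks. A minimal sketch, not part of the commit; the helper name `unpack_levels` and the cap of 8 levels are illustrative assumptions only:

```python
# Minimal sketch, not repo code: decoding the 4-bit-per-level packing that
# rank_idx and rank_all use. `unpack_levels` and the 8-level cap are assumptions.
def unpack_levels(packed: int, max_levels: int = 8) -> list:
    """Return the 4-bit field of every level, lowest level first."""
    return [(packed >> (4 * level)) & 0xF for level in range(max_levels)]

# Level 0 holds value 2, level 1 holds value 5; unused high nibbles are 1-filled.
packed = (0xFFFFFF << 8) | (5 << 4) | 2
print(unpack_levels(packed))  # [2, 5, 15, 15, 15, 15, 15, 15]
```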
@@ -429,9 +429,12 @@ class MeaningDataset(Dataset):
         output = {}
         data = torch.tensor(np.stack(data, axis=0)).long()
         output["input_ids"] = data
-        output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
-        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
+        val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
+        output["val_mask"] = val_mask
+        labels = data.clone()
+        labels[~stride_mask] = self.vocab_size  # positions set to vocab_size are masked out of the labels
+        output["labels"] = labels
         output["meaning"] = [self.seq_meaning[i] for i in idx_list]
         return output
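The comment in the diff says labels set to vocab_size get masked; the loss code itself is not part of this commit. A hedged sketch, assuming a standard cross entropy whose ignore_index is wired to the same sentinel:

```python
import torch
import torch.nn.functional as F

# Hedged sketch, not from this repo: one way the vocab_size sentinel written
# into `labels` can be skipped by the loss. The ignore_index wiring below is
# an assumption; only the sentinel convention comes from the diff above.
vocab_size = 1024
logits = torch.randn(2, 8, vocab_size)           # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 8))    # clone of input_ids
stride_mask = torch.rand(2, 8) > 0.2             # True where the token is kept
labels[~stride_mask] = vocab_size                # sentinel: do not train on this

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=vocab_size,                     # skips the sentinel positions
)
print(loss)
```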
@@ -457,27 +460,22 @@ class MeaningDataset(Dataset):
         middle = int(l * ratio)
         return self.copy(0, middle), self.copy(middle, l)

-    def get_seq_mask(self, idx, level, index):
+    def get_rank_mask(self, idx, level, index):
         # assert index < 15, "index must < 15"
         # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
-        return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)
+        return rank_idx == (rank_all + index if index < 0 else index)

     def get_seq_mask_tensor(self, idx_list):
-        if self.val_mask_level is not None and self.val_mask_idx is not None:
-            mask = torch.tensor(
-                np.stack(
-                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
-                )
-            )
-            for i, l in enumerate(self.val_mask_level[1:]):
-                mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
-                )
-            return mask
-        else:
-            return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        val_mask = stride_mask.clone()
+        if self.val_mask_level and self.val_mask_idx:
+            for i, l in enumerate(self.val_mask_level):
+                val_mask = val_mask & torch.tensor(
+                    np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0)
+                )
+        return (val_mask, stride_mask)


 class BatchGroupMeaningDataloader(Dataset):
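After this change get_rank_mask no longer folds the `level < 255` stride check into its result; it purely compares the selected 4-bit field. A toy check of that comparison with hand-built packed arrays (the values are made up for illustration):

```python
import numpy as np

# Toy check, not repo code: the nibble comparison inside get_rank_mask on a
# hand-built sequence of three tokens forming one 3-branch node at level 0.
rank_idx_packed = np.array([0xF0, 0xF1, 0xF2], dtype=np.int64)  # positions 0, 1, 2
rank_all_packed = np.array([0xF3, 0xF3, 0xF3], dtype=np.int64)  # 3 branches each

level, index = 0, -1  # "last token of its level-0 node"
rank_idx = (rank_idx_packed >> (4 * level)).astype(np.int32) & 0xF
rank_all = (rank_all_packed >> (4 * level)).astype(np.int32) & 0xF
print(rank_idx == (rank_all + index if index < 0 else index))  # [False False  True]
```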
@@ -586,13 +584,13 @@ if __name__ == "__main__":

     print(md.rank_idx[920])
     print(md.rank_all[920])
-    mask = md.get_seq_mask(920, 0, -1)
+    mask = md.get_rank_mask(920, 0, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 0)
+    mask = md.get_rank_mask(920, 1, 0)
     print(mask)
-    mask = md.get_seq_mask(920, 1, -1)
+    mask = md.get_rank_mask(920, 1, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 1)
+    mask = md.get_rank_mask(920, 1, 1)
     print(mask)

     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
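With val_mask_level and val_mask_idx set, get_seq_mask_tensor now ANDs one rank mask per configured level into a copy of the stride mask and returns both. A hedged sketch of that combination for the documented example val_mask_level=[0, 1, 2], val_mask_idx=[0, 0, -1]; the per-level masks below stand in for get_rank_mask results and are made up:

```python
import torch

# Hedged sketch, not repo code: how val_mask is assembled after this commit.
stride_mask = torch.tensor([[True, True, True, True, False]])  # level < 255
per_level_masks = [                                            # one per (level, idx) pair
    torch.tensor([[True, False, True, True, False]]),          # level 0, index 0
    torch.tensor([[True, True, False, True, False]]),          # level 1, index 0
    torch.tensor([[False, True, True, True, False]]),          # level 2, index -1
]

val_mask = stride_mask.clone()
for m in per_level_masks:
    val_mask = val_mask & m   # a token must match every configured level
print(val_mask)               # tensor([[False, False, False,  True, False]])
# get_seq_mask_tensor returns (val_mask, stride_mask)
```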