Update meaning of stride mask label.

parent 3f0eedfef8
commit 2a09b9d9b1
@@ -14,7 +14,7 @@ The meaning dataset imitates natural language as well as abstract expression.
 8. rank_idx is the rank number of the current token at each level; every 4 bits encode the rank within one level, the lowest 4 bits hold the rank_idx of the lowest level, and unused high bits are filled with 1
 9. rank_all is the total number of branches at each level the current token belongs to; every 4 bits encode the count within one level, the lowest 4 bits hold the rank_all of the lowest level, and unused high bits are filled with 1
 10. tree stores the decomposition data of each meaning, using a dict to express a tree structure
-11. get_seq_mask returns, for every token of a sequence, whether it sits at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
+11. get_rank_mask returns, for every token of a sequence, whether it sits at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level val_mask_idx: the mask over tokens used for training; val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1] means a token takes part in training only if it is the 0th at level 0, the 0th at level 1, and the last at level 2
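For reference, a minimal sketch of the 4-bit-per-level packing that items 8 and 9 describe, assuming ranks are stored low-level-first; the helper name decode_level is illustrative and not part of the repository:

import numpy as np

def decode_level(packed, level):
    # Shift the packed field down to the requested level and keep 4 bits,
    # mirroring the expression used in get_rank_mask below.
    return (packed >> (4 * level)).astype(np.int32) & 0xF

packed = np.array([0xFF21], dtype=np.int64)  # level 0 -> 1, level 1 -> 2, high bits filled with 1
print(decode_level(packed, 0))  # [1]
print(decode_level(packed, 1))  # [2]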
@@ -429,9 +429,12 @@ class MeaningDataset(Dataset):
         output = {}
         data = torch.tensor(np.stack(data, axis=0)).long()
         output["input_ids"] = data
-        output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
-        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
+        val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
+        output["val_mask"] = val_mask
+        labels = data.clone()
+        labels[~stride_mask] = self.vocab_size  # set to vocab_size will be masked in label
+        output["labels"] = labels
         output["meaning"] = [self.seq_meaning[i] for i in idx_list]
         return output
 
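The new comment states that labels set to vocab_size are masked out of the loss. A hedged sketch of one way that convention can work downstream, assuming the loss is configured with an ignore_index; the tensors here are made up for illustration:

import torch

# Labels equal to vocab_size fall outside the class range [0, vocab_size),
# so a loss built with ignore_index=vocab_size drops those positions.
vocab_size = 1024
logits = torch.randn(4, vocab_size)            # one row of logits per token
labels = torch.tensor([3, vocab_size, 7, 9])   # position 1 is masked out
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=vocab_size)
print(loss_fn(logits, labels))                 # position 1 contributes nothing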
@@ -457,27 +460,22 @@ class MeaningDataset(Dataset):
         middle = int(l * ratio)
         return self.copy(0, middle), self.copy(middle, l)
 
-    def get_seq_mask(self, idx, level, index):
+    def get_rank_mask(self, idx, level, index):
         # assert index < 15, "index must < 15"
         # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
-        return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)
+        return rank_idx == (rank_all + index if index < 0 else index)
 
     def get_seq_mask_tensor(self, idx_list):
-        if self.val_mask_level is not None and self.val_mask_idx is not None:
-            mask = torch.tensor(
-                np.stack(
-                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
-                )
-            )
-            for i, l in enumerate(self.val_mask_level[1:]):
-                mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
-                )
-            return mask
-        else:
-            return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        val_mask = stride_mask.clone()
+        if self.val_mask_level and self.val_mask_idx:
+            for i, l in enumerate(self.val_mask_level):
+                val_mask = val_mask & torch.tensor(
+                    np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0)
+                )
+        return (val_mask, stride_mask)
 
 
 class BatchGroupMeaningDataloader(Dataset):
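To illustrate the refactor, a small sketch with assumed values of how the rewritten get_seq_mask_tensor composes its two outputs: stride_mask keeps tokens whose level byte is valid (< 255), and val_mask additionally ANDs in each configured (level, index) rank condition:

import numpy as np
import torch

level = np.array([0, 1, 255, 0])                # 255 marks an unused slot
rank_ok = np.array([True, False, True, True])   # stand-in for one get_rank_mask result

stride_mask = torch.tensor(level < 255)
val_mask = stride_mask & torch.tensor(rank_ok)
print(stride_mask)  # tensor([ True,  True, False,  True])
print(val_mask)     # tensor([ True, False, False,  True])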
@@ -586,13 +584,13 @@ if __name__ == "__main__":
     print(md.rank_idx[920])
     print(md.rank_all[920])
-    mask = md.get_seq_mask(920, 0, -1)
+    mask = md.get_rank_mask(920, 0, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 0)
+    mask = md.get_rank_mask(920, 1, 0)
     print(mask)
-    mask = md.get_seq_mask(920, 1, -1)
+    mask = md.get_rank_mask(920, 1, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 1)
+    mask = md.get_rank_mask(920, 1, 1)
     print(mask)
 
     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
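The demo calls above pass index=-1; a worked example of the negative-index rule inside get_rank_mask, assuming ranks at a level are numbered 0..rank_all-1:

# For index < 0 the comparison target becomes rank_all + index,
# so index=-1 selects the last branch at that level.
rank_all = 3                                     # three branches at this level
index = -1
target = rank_all + index if index < 0 else index
print(target)  # 2, the last of ranks 0, 1, 2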