diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md
index 6d1fdf8..b6a47ba 100644
--- a/wit/doc/meaning_dataset.md
+++ b/wit/doc/meaning_dataset.md
@@ -14,7 +14,7 @@ The meaning dataset is a dataset that imitates natural language as well as abstract expression.
 8. rank_idx is the ordering index of the current token at each level; every 4 bits encode the index within one level, the low 4 bits hold the rank_idx of the lowest level, and unused high bits are filled with 1
 9. rank_all is the total number of branches at each level that the current token belongs to; every 4 bits encode the count within one level, the low 4 bits hold the rank_all of the lowest level, and unused high bits are filled with 1
 10. tree stores the decomposition data of each meaning, using a dict to express a tree structure
-11. get_seq_mask returns, for every token of a sequence, whether it sits at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
+11. get_rank_mask returns, for every token of a sequence, whether it sits at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
 12. meaning_height is the total height of the current meaning
 13. meaning_weight is the total width of the current meaning
 14. val_mask_level val_mask_idx: the mask of tokens used for training; val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1] means that only tokens that are the 0th at level 0, the 0th at level 1, and the last at level 2 take part in training
diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py
index 0059d32..8c6f501 100644
--- a/wit/meaning/meaning_dataset.py
+++ b/wit/meaning/meaning_dataset.py
@@ -429,9 +429,12 @@ class MeaningDataset(Dataset):
         output = {}
         data = torch.tensor(np.stack(data, axis=0)).long()
         output["input_ids"] = data
-        output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
-        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
+        val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
+        output["val_mask"] = val_mask
+        labels = data.clone()
+        labels[~stride_mask] = self.vocab_size  # positions set to vocab_size are masked out of the labels
+        output["labels"] = labels
         output["meaning"] = [self.seq_meaning[i] for i in idx_list]
         return output
 
@@ -457,27 +460,22 @@ class MeaningDataset(Dataset):
         middle = int(l * ratio)
         return self.copy(0, middle), self.copy(middle, l)
 
-    def get_seq_mask(self, idx, level, index):
+    def get_rank_mask(self, idx, level, index):
         # assert index < 15, "index must < 15"
         # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
-        return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)
+        return rank_idx == (rank_all + index if index < 0 else index)
 
     def get_seq_mask_tensor(self, idx_list):
-        if self.val_mask_level is not None and self.val_mask_idx is not None:
-            mask = torch.tensor(
-                np.stack(
-                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
+        stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        val_mask = stride_mask.clone()
+        if self.val_mask_level and self.val_mask_idx:
+            for i, l in enumerate(self.val_mask_level):
+                val_mask = val_mask & torch.tensor(
+                    np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0)
                 )
-            )
-            for i, l in enumerate(self.val_mask_level[1:]):
-                mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
-                )
-            return mask
-        else:
-            return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
+        return (val_mask, stride_mask)
 
 
 class BatchGroupMeaningDataloader(Dataset):
@@ -586,13 +584,13 @@ if __name__ == "__main__":
     print(md.rank_idx[920])
     print(md.rank_all[920])
 
-    mask = md.get_seq_mask(920, 0, -1)
+    mask = md.get_rank_mask(920, 0, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 0)
+    mask = md.get_rank_mask(920, 1, 0)
     print(mask)
-    mask = md.get_seq_mask(920, 1, -1)
+    mask = md.get_rank_mask(920, 1, -1)
     print(mask)
-    mask = md.get_seq_mask(920, 1, 1)
+    mask = md.get_rank_mask(920, 1, 1)
     print(mask)
 
     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
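A minimal sketch, not part of the diff above, of how a training step might consume the batch dict that the modified MeaningDataset code now builds: positions outside `stride_mask` arrive with their label set to `vocab_size`, so they can be dropped from the loss via `ignore_index`, while `val_mask` restricts validation metrics to the tokens selected by `val_mask_level` / `val_mask_idx`. All tensor shapes and values here are placeholders, and the 4-bit decode at the top just mirrors the `rank_idx` packing described in the doc.

```python
import torch
import torch.nn.functional as F

# 4-bit packing from the doc: low 4 bits = rank at level 0, next 4 bits = level 1,
# unused high bits filled with 1. (0xFF21 is a hypothetical packed value.)
packed_rank_idx = 0xFF21
print((packed_rank_idx >> 4) & 0xF)  # rank within level 1 -> 2

vocab_size = 1024
batch, seq_len = 2, 16

# Stand-ins for one batch built by the modified MeaningDataset code (dummy values).
labels = torch.randint(0, vocab_size, (batch, seq_len))
labels[:, -3:] = vocab_size                       # non-stride positions, set to vocab_size by the dataset
val_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
val_mask[:, 0] = True                             # e.g. only the level-0, index-0 tokens
logits = torch.randn(batch, seq_len, vocab_size)  # dummy model output

# Labels equal to vocab_size contribute nothing to the loss.
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=vocab_size,
)

# Validation accuracy only over the positions selected by val_mask.
pred = logits.argmax(dim=-1)
acc = (pred[val_mask] == labels[val_mask]).float().mean()
print(loss.item(), acc.item())
```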