diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md index b6a47ba..0816db4 100644 --- a/wit/doc/meaning_dataset.md +++ b/wit/doc/meaning_dataset.md @@ -11,8 +11,8 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。 5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据 6. level表示当前token相对于root meaning的距离 7. rank -8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充 -9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充 +8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充,tree token总是排在最后面,index递增 +9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充,tree token不计入rank all的总数 10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构 11. get_rank_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个 12. meaning_height 当前meaning的总高度 diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py index 597ddca..bc22f65 100644 --- a/wit/meaning/meaning_dataset.py +++ b/wit/meaning/meaning_dataset.py @@ -139,8 +139,6 @@ class MeaningMap: m = map[i] # 当前meaning的拆分的分支 m = m[m >= 0] # donot cut off the map such as [0] m_len = len(m) # 当前meaning的拆分的分支个数 - if self.with_tree: - m_len = m_len + 1 m_list = m.tolist() assert m_list, "map list can not be empty list" @@ -173,7 +171,7 @@ class MeaningMap: # 处理rank_idx new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) if self.with_tree: - new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])]) + new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len])]) ms_rank_idx[index:end] = new_idx # 处理rank_all new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) @@ -430,10 +428,10 @@ class MeaningDataset(Dataset): data = torch.tensor(np.stack(data, axis=0)).long() output["input_ids"] = data output["token_type_ids"] = torch.zeros(data.shape) - val_mask, stride_mask = self.get_seq_mask_tensor(idx_list) + val_mask, token_mask = self.get_seq_mask_tensor(idx_list) output["val_mask"] = val_mask labels = data.clone() - # labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label + # labels[~token_mask] = self.vocab_size # set to vocab_size will be masked in label output["labels"] = labels output["meaning"] = [self.seq_meaning[i] for i in idx_list] return output @@ -468,14 +466,13 @@ class MeaningDataset(Dataset): return rank_idx == (rank_all + index if index < 0 else index) def get_seq_mask_tensor(self, idx_list): - stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0)) - val_mask = stride_mask.clone() + token_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0)) + val_mask = token_mask.clone() if self.val_mask_level and self.val_mask_idx: for i, l in enumerate(self.val_mask_level): - val_mask = val_mask & torch.tensor( - np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0) - ) - return (val_mask, stride_mask) + mask = [self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list] + val_mask = val_mask & torch.tensor(np.stack(mask, axis=0)) + return (val_mask, token_mask) class BatchGroupMeaningDataloader(Dataset):