Update meaning of stride mask label.
This commit is contained in:
		
							parent
							
								
									3f0eedfef8
								
							
						
					
					
						commit
						2a09b9d9b1
					
				|  | @ -14,7 +14,7 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。 | ||||||
| 8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充 | 8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充 | ||||||
| 9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充 | 9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充 | ||||||
| 10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构 | 10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构 | ||||||
| 11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个 | 11. get_rank_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个 | ||||||
| 12. meaning_height 当前meaning的总高度 | 12. meaning_height 当前meaning的总高度 | ||||||
| 13. meaning_weight 当前meaning的总宽度 | 13. meaning_weight 当前meaning的总宽度 | ||||||
| 14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练 | 14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练 | ||||||
|  |  | ||||||
|  | @ -429,9 +429,12 @@ class MeaningDataset(Dataset): | ||||||
|         output = {} |         output = {} | ||||||
|         data = torch.tensor(np.stack(data, axis=0)).long() |         data = torch.tensor(np.stack(data, axis=0)).long() | ||||||
|         output["input_ids"] = data |         output["input_ids"] = data | ||||||
|         output["labels"] = data.clone() |  | ||||||
|         output["token_type_ids"] = torch.zeros(data.shape) |         output["token_type_ids"] = torch.zeros(data.shape) | ||||||
|         output["val_mask"] = self.get_seq_mask_tensor(idx_list) |         val_mask, stride_mask = self.get_seq_mask_tensor(idx_list) | ||||||
|  |         output["val_mask"] = val_mask | ||||||
|  |         labels = data.clone() | ||||||
|  |         labels[~stride_mask] = self.vocab_size  # set to vocab_size will be masked in label | ||||||
|  |         output["labels"] = labels | ||||||
|         output["meaning"] = [self.seq_meaning[i] for i in idx_list] |         output["meaning"] = [self.seq_meaning[i] for i in idx_list] | ||||||
|         return output |         return output | ||||||
| 
 | 
 | ||||||
|  | @ -457,27 +460,22 @@ class MeaningDataset(Dataset): | ||||||
|         middle = int(l * ratio) |         middle = int(l * ratio) | ||||||
|         return self.copy(0, middle), self.copy(middle, l) |         return self.copy(0, middle), self.copy(middle, l) | ||||||
| 
 | 
 | ||||||
|     def get_seq_mask(self, idx, level, index): |     def get_rank_mask(self, idx, level, index): | ||||||
|         # assert index < 15, "index must < 15" |         # assert index < 15, "index must < 15" | ||||||
|         # assert level < 8, "level must < 8" |         # assert level < 8, "level must < 8" | ||||||
|         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF |         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF | ||||||
|         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF |         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF | ||||||
|         return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255) |         return rank_idx == (rank_all + index if index < 0 else index) | ||||||
| 
 | 
 | ||||||
|     def get_seq_mask_tensor(self, idx_list): |     def get_seq_mask_tensor(self, idx_list): | ||||||
|         if self.val_mask_level is not None and self.val_mask_idx is not None: |         stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0)) | ||||||
|             mask = torch.tensor( |         val_mask = stride_mask.clone() | ||||||
|                 np.stack( |         if self.val_mask_level and self.val_mask_idx: | ||||||
|                     [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0 |             for i, l in enumerate(self.val_mask_level): | ||||||
|  |                 val_mask = val_mask & torch.tensor( | ||||||
|  |                     np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0) | ||||||
|                 ) |                 ) | ||||||
|             ) |         return (val_mask, stride_mask) | ||||||
|             for i, l in enumerate(self.val_mask_level[1:]): |  | ||||||
|                 mask = mask & torch.tensor( |  | ||||||
|                     np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0) |  | ||||||
|                 ) |  | ||||||
|             return mask |  | ||||||
|         else: |  | ||||||
|             return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0)) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class BatchGroupMeaningDataloader(Dataset): | class BatchGroupMeaningDataloader(Dataset): | ||||||
|  | @ -586,13 +584,13 @@ if __name__ == "__main__": | ||||||
| 
 | 
 | ||||||
|     print(md.rank_idx[920]) |     print(md.rank_idx[920]) | ||||||
|     print(md.rank_all[920]) |     print(md.rank_all[920]) | ||||||
|     mask = md.get_seq_mask(920, 0, -1) |     mask = md.get_rank_mask(920, 0, -1) | ||||||
|     print(mask) |     print(mask) | ||||||
|     mask = md.get_seq_mask(920, 1, 0) |     mask = md.get_rank_mask(920, 1, 0) | ||||||
|     print(mask) |     print(mask) | ||||||
|     mask = md.get_seq_mask(920, 1, -1) |     mask = md.get_rank_mask(920, 1, -1) | ||||||
|     print(mask) |     print(mask) | ||||||
|     mask = md.get_seq_mask(920, 1, 1) |     mask = md.get_rank_mask(920, 1, 1) | ||||||
|     print(mask) |     print(mask) | ||||||
| 
 | 
 | ||||||
|     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False) |     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Colin
						Colin