Refine mapping print.
This commit is contained in:
		
							parent
							
								
									1642a91d80
								
							
						
					
					
						commit
						3c774983d4
					
				|  | @ -140,7 +140,7 @@ class MeaningDataset(Dataset): | |||
|         for m in meanings: | ||||
|             sq = mm.get_sequence(m) | ||||
|             if len(sq) >= min_seq_len: | ||||
|                 self.mapping.append(mm.get_mapping(m)) | ||||
|                 self.mapping.append({m: mm.get_mapping(m)}) | ||||
|                 self.data.append(sq) | ||||
|                 self.length.append(len(sq)) | ||||
| 
 | ||||
|  | @ -175,9 +175,41 @@ class MeaningDataset(Dataset): | |||
|         output["token_type_ids"] = torch.zeros(data.shape) | ||||
|         return output | ||||
| 
 | ||||
|     def get_token_batch(self, index_list):  # must equal sequence length | ||||
|         return [self.data[i] for i in index_list] | ||||
| 
 | ||||
|     def print_token_batch(self, index_list):  # must equal sequence length | ||||
|         data = [self.data[i] for i in index_list] | ||||
|         output = {} | ||||
|         data = torch.tensor(np.stack(data, axis=0)).long() | ||||
|         output["input_ids"] = data | ||||
|         output["labels"] = data.clone() | ||||
|         output["token_type_ids"] = torch.zeros(data.shape) | ||||
|         return output | ||||
| 
 | ||||
|     def get_mapping_batch(self, index_list): | ||||
|         return [self.mapping[i] for i in index_list] | ||||
| 
 | ||||
|     def __get_mapping_str__(map, prefix): | ||||
|         if isinstance(map, dict): | ||||
|             base = "" | ||||
|             for key, value in map.items(): | ||||
|                 base += prefix + str(key) + "\n" | ||||
|                 base += MeaningDataset.__get_mapping_str__(value, prefix + "    ") | ||||
|             return base | ||||
|         else: | ||||
|             return "" | ||||
| 
 | ||||
|     def print_mapping_batch(self, index_list): | ||||
|         tokens = self.get_token_batch(index_list) | ||||
|         map = self.get_mapping_batch(index_list) | ||||
|         s = "--------------------------------------------------------\n" | ||||
|         for i, m in enumerate(map): | ||||
|             s += str(tokens[i]) + "\n" | ||||
|             s += MeaningDataset.__get_mapping_str__(m, "") | ||||
|             s += "--------------------------------------------------------\n" | ||||
|         return s | ||||
| 
 | ||||
|     def split(self, ratio): | ||||
|         l = len(self.data) | ||||
|         middle = int(l * ratio) | ||||
|  | @ -238,22 +270,7 @@ class BatchGroupMeaningDataloader(Dataset): | |||
|         return self.dataset.get_mapping_batch(self.indexBatch[idx]) | ||||
| 
 | ||||
|     def print_mapping(self, idx): | ||||
|         map = self.mapping(idx) | ||||
|         s = "" | ||||
|         for m in map: | ||||
|             s += BatchGroupMeaningDataloader.get_mapping_str(m, "") | ||||
|             s += "--------\n" | ||||
|         return s | ||||
| 
 | ||||
|     def get_mapping_str(map, prefix): | ||||
|         if isinstance(map, dict): | ||||
|             base = "" | ||||
|             for key, value in map.items(): | ||||
|                 base += prefix + str(key) + "\n" | ||||
|                 base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + "    ") | ||||
|             return base | ||||
|         else: | ||||
|             return "" | ||||
|         return self.dataset.print_mapping_batch(self.indexBatch[idx]) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue