Refine mapping print.
This commit is contained in:
		
							parent
							
								
									1642a91d80
								
							
						
					
					
						commit
						3c774983d4
					
				| 
						 | 
					@ -140,7 +140,7 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        for m in meanings:
 | 
					        for m in meanings:
 | 
				
			||||||
            sq = mm.get_sequence(m)
 | 
					            sq = mm.get_sequence(m)
 | 
				
			||||||
            if len(sq) >= min_seq_len:
 | 
					            if len(sq) >= min_seq_len:
 | 
				
			||||||
                self.mapping.append(mm.get_mapping(m))
 | 
					                self.mapping.append({m: mm.get_mapping(m)})
 | 
				
			||||||
                self.data.append(sq)
 | 
					                self.data.append(sq)
 | 
				
			||||||
                self.length.append(len(sq))
 | 
					                self.length.append(len(sq))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -175,9 +175,41 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
        return output
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_token_batch(self, index_list):  # must equal sequence length
 | 
				
			||||||
 | 
					        return [self.data[i] for i in index_list]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def print_token_batch(self, index_list):  # must equal sequence length
 | 
				
			||||||
 | 
					        data = [self.data[i] for i in index_list]
 | 
				
			||||||
 | 
					        output = {}
 | 
				
			||||||
 | 
					        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
				
			||||||
 | 
					        output["input_ids"] = data
 | 
				
			||||||
 | 
					        output["labels"] = data.clone()
 | 
				
			||||||
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_mapping_batch(self, index_list):
 | 
					    def get_mapping_batch(self, index_list):
 | 
				
			||||||
        return [self.mapping[i] for i in index_list]
 | 
					        return [self.mapping[i] for i in index_list]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __get_mapping_str__(map, prefix):
 | 
				
			||||||
 | 
					        if isinstance(map, dict):
 | 
				
			||||||
 | 
					            base = ""
 | 
				
			||||||
 | 
					            for key, value in map.items():
 | 
				
			||||||
 | 
					                base += prefix + str(key) + "\n"
 | 
				
			||||||
 | 
					                base += MeaningDataset.__get_mapping_str__(value, prefix + "    ")
 | 
				
			||||||
 | 
					            return base
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def print_mapping_batch(self, index_list):
 | 
				
			||||||
 | 
					        tokens = self.get_token_batch(index_list)
 | 
				
			||||||
 | 
					        map = self.get_mapping_batch(index_list)
 | 
				
			||||||
 | 
					        s = "--------------------------------------------------------\n"
 | 
				
			||||||
 | 
					        for i, m in enumerate(map):
 | 
				
			||||||
 | 
					            s += str(tokens[i]) + "\n"
 | 
				
			||||||
 | 
					            s += MeaningDataset.__get_mapping_str__(m, "")
 | 
				
			||||||
 | 
					            s += "--------------------------------------------------------\n"
 | 
				
			||||||
 | 
					        return s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def split(self, ratio):
 | 
					    def split(self, ratio):
 | 
				
			||||||
        l = len(self.data)
 | 
					        l = len(self.data)
 | 
				
			||||||
        middle = int(l * ratio)
 | 
					        middle = int(l * ratio)
 | 
				
			||||||
| 
						 | 
					@ -238,22 +270,7 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
        return self.dataset.get_mapping_batch(self.indexBatch[idx])
 | 
					        return self.dataset.get_mapping_batch(self.indexBatch[idx])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_mapping(self, idx):
 | 
					    def print_mapping(self, idx):
 | 
				
			||||||
        map = self.mapping(idx)
 | 
					        return self.dataset.print_mapping_batch(self.indexBatch[idx])
 | 
				
			||||||
        s = ""
 | 
					 | 
				
			||||||
        for m in map:
 | 
					 | 
				
			||||||
            s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
 | 
					 | 
				
			||||||
            s += "--------\n"
 | 
					 | 
				
			||||||
        return s
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_mapping_str(map, prefix):
 | 
					 | 
				
			||||||
        if isinstance(map, dict):
 | 
					 | 
				
			||||||
            base = ""
 | 
					 | 
				
			||||||
            for key, value in map.items():
 | 
					 | 
				
			||||||
                base += prefix + str(key) + "\n"
 | 
					 | 
				
			||||||
                base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + "    ")
 | 
					 | 
				
			||||||
            return base
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return ""
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue