Refine mapping print.
This commit is contained in:
parent
1642a91d80
commit
3c774983d4
|
@ -140,7 +140,7 @@ class MeaningDataset(Dataset):
|
|||
for m in meanings:
|
||||
sq = mm.get_sequence(m)
|
||||
if len(sq) >= min_seq_len:
|
||||
self.mapping.append(mm.get_mapping(m))
|
||||
self.mapping.append({m: mm.get_mapping(m)})
|
||||
self.data.append(sq)
|
||||
self.length.append(len(sq))
|
||||
|
||||
|
@ -175,9 +175,41 @@ class MeaningDataset(Dataset):
|
|||
output["token_type_ids"] = torch.zeros(data.shape)
|
||||
return output
|
||||
|
||||
def get_token_batch(self, index_list): # must equal sequence length
|
||||
return [self.data[i] for i in index_list]
|
||||
|
||||
def print_token_batch(self, index_list): # must equal sequence length
|
||||
data = [self.data[i] for i in index_list]
|
||||
output = {}
|
||||
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||
output["input_ids"] = data
|
||||
output["labels"] = data.clone()
|
||||
output["token_type_ids"] = torch.zeros(data.shape)
|
||||
return output
|
||||
|
||||
def get_mapping_batch(self, index_list):
|
||||
return [self.mapping[i] for i in index_list]
|
||||
|
||||
def __get_mapping_str__(map, prefix):
|
||||
if isinstance(map, dict):
|
||||
base = ""
|
||||
for key, value in map.items():
|
||||
base += prefix + str(key) + "\n"
|
||||
base += MeaningDataset.__get_mapping_str__(value, prefix + " ")
|
||||
return base
|
||||
else:
|
||||
return ""
|
||||
|
||||
def print_mapping_batch(self, index_list):
|
||||
tokens = self.get_token_batch(index_list)
|
||||
map = self.get_mapping_batch(index_list)
|
||||
s = "--------------------------------------------------------\n"
|
||||
for i, m in enumerate(map):
|
||||
s += str(tokens[i]) + "\n"
|
||||
s += MeaningDataset.__get_mapping_str__(m, "")
|
||||
s += "--------------------------------------------------------\n"
|
||||
return s
|
||||
|
||||
def split(self, ratio):
|
||||
l = len(self.data)
|
||||
middle = int(l * ratio)
|
||||
|
@ -238,22 +270,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
||||
|
||||
def print_mapping(self, idx):
|
||||
map = self.mapping(idx)
|
||||
s = ""
|
||||
for m in map:
|
||||
s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
|
||||
s += "--------\n"
|
||||
return s
|
||||
|
||||
def get_mapping_str(map, prefix):
|
||||
if isinstance(map, dict):
|
||||
base = ""
|
||||
for key, value in map.items():
|
||||
base += prefix + str(key) + "\n"
|
||||
base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
|
||||
return base
|
||||
else:
|
||||
return ""
|
||||
return self.dataset.print_mapping_batch(self.indexBatch[idx])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue