Refine mapping print.
This commit is contained in:
parent
1642a91d80
commit
3c774983d4
|
@ -140,7 +140,7 @@ class MeaningDataset(Dataset):
|
||||||
for m in meanings:
|
for m in meanings:
|
||||||
sq = mm.get_sequence(m)
|
sq = mm.get_sequence(m)
|
||||||
if len(sq) >= min_seq_len:
|
if len(sq) >= min_seq_len:
|
||||||
self.mapping.append(mm.get_mapping(m))
|
self.mapping.append({m: mm.get_mapping(m)})
|
||||||
self.data.append(sq)
|
self.data.append(sq)
|
||||||
self.length.append(len(sq))
|
self.length.append(len(sq))
|
||||||
|
|
||||||
|
@ -175,9 +175,41 @@ class MeaningDataset(Dataset):
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def get_token_batch(self, index_list): # must equal sequence length
|
||||||
|
return [self.data[i] for i in index_list]
|
||||||
|
|
||||||
|
def print_token_batch(self, index_list): # must equal sequence length
|
||||||
|
data = [self.data[i] for i in index_list]
|
||||||
|
output = {}
|
||||||
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||||
|
output["input_ids"] = data
|
||||||
|
output["labels"] = data.clone()
|
||||||
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
|
return output
|
||||||
|
|
||||||
def get_mapping_batch(self, index_list):
|
def get_mapping_batch(self, index_list):
|
||||||
return [self.mapping[i] for i in index_list]
|
return [self.mapping[i] for i in index_list]
|
||||||
|
|
||||||
|
def __get_mapping_str__(map, prefix):
|
||||||
|
if isinstance(map, dict):
|
||||||
|
base = ""
|
||||||
|
for key, value in map.items():
|
||||||
|
base += prefix + str(key) + "\n"
|
||||||
|
base += MeaningDataset.__get_mapping_str__(value, prefix + " ")
|
||||||
|
return base
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def print_mapping_batch(self, index_list):
|
||||||
|
tokens = self.get_token_batch(index_list)
|
||||||
|
map = self.get_mapping_batch(index_list)
|
||||||
|
s = "--------------------------------------------------------\n"
|
||||||
|
for i, m in enumerate(map):
|
||||||
|
s += str(tokens[i]) + "\n"
|
||||||
|
s += MeaningDataset.__get_mapping_str__(m, "")
|
||||||
|
s += "--------------------------------------------------------\n"
|
||||||
|
return s
|
||||||
|
|
||||||
def split(self, ratio):
|
def split(self, ratio):
|
||||||
l = len(self.data)
|
l = len(self.data)
|
||||||
middle = int(l * ratio)
|
middle = int(l * ratio)
|
||||||
|
@ -238,22 +270,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
||||||
|
|
||||||
def print_mapping(self, idx):
|
def print_mapping(self, idx):
|
||||||
map = self.mapping(idx)
|
return self.dataset.print_mapping_batch(self.indexBatch[idx])
|
||||||
s = ""
|
|
||||||
for m in map:
|
|
||||||
s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
|
|
||||||
s += "--------\n"
|
|
||||||
return s
|
|
||||||
|
|
||||||
def get_mapping_str(map, prefix):
|
|
||||||
if isinstance(map, dict):
|
|
||||||
base = ""
|
|
||||||
for key, value in map.items():
|
|
||||||
base += prefix + str(key) + "\n"
|
|
||||||
base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
|
|
||||||
return base
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue