Refine mapping print.

This commit is contained in:
Colin 2024-04-03 13:03:59 +08:00
parent 1642a91d80
commit 3c774983d4
1 changed files with 34 additions and 17 deletions

View File

@ -140,7 +140,7 @@ class MeaningDataset(Dataset):
for m in meanings: for m in meanings:
sq = mm.get_sequence(m) sq = mm.get_sequence(m)
if len(sq) >= min_seq_len: if len(sq) >= min_seq_len:
self.mapping.append(mm.get_mapping(m)) self.mapping.append({m: mm.get_mapping(m)})
self.data.append(sq) self.data.append(sq)
self.length.append(len(sq)) self.length.append(len(sq))
@ -175,9 +175,41 @@ class MeaningDataset(Dataset):
output["token_type_ids"] = torch.zeros(data.shape) output["token_type_ids"] = torch.zeros(data.shape)
return output return output
def get_token_batch(self, index_list): # must equal sequence length
return [self.data[i] for i in index_list]
def print_token_batch(self, index_list): # must equal sequence length
data = [self.data[i] for i in index_list]
output = {}
data = torch.tensor(np.stack(data, axis=0)).long()
output["input_ids"] = data
output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape)
return output
def get_mapping_batch(self, index_list): def get_mapping_batch(self, index_list):
return [self.mapping[i] for i in index_list] return [self.mapping[i] for i in index_list]
def __get_mapping_str__(map, prefix):
if isinstance(map, dict):
base = ""
for key, value in map.items():
base += prefix + str(key) + "\n"
base += MeaningDataset.__get_mapping_str__(value, prefix + " ")
return base
else:
return ""
def print_mapping_batch(self, index_list):
tokens = self.get_token_batch(index_list)
map = self.get_mapping_batch(index_list)
s = "--------------------------------------------------------\n"
for i, m in enumerate(map):
s += str(tokens[i]) + "\n"
s += MeaningDataset.__get_mapping_str__(m, "")
s += "--------------------------------------------------------\n"
return s
def split(self, ratio): def split(self, ratio):
l = len(self.data) l = len(self.data)
middle = int(l * ratio) middle = int(l * ratio)
@ -238,22 +270,7 @@ class BatchGroupMeaningDataloader(Dataset):
return self.dataset.get_mapping_batch(self.indexBatch[idx]) return self.dataset.get_mapping_batch(self.indexBatch[idx])
def print_mapping(self, idx): def print_mapping(self, idx):
map = self.mapping(idx) return self.dataset.print_mapping_batch(self.indexBatch[idx])
s = ""
for m in map:
s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
s += "--------\n"
return s
def get_mapping_str(map, prefix):
if isinstance(map, dict):
base = ""
for key, value in map.items():
base += prefix + str(key) + "\n"
base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
return base
else:
return ""
if __name__ == "__main__": if __name__ == "__main__":