diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index e18b8f2..6d4057a 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -140,7 +140,7 @@ class MeaningDataset(Dataset): for m in meanings: sq = mm.get_sequence(m) if len(sq) >= min_seq_len: - self.mapping.append(mm.get_mapping(m)) + self.mapping.append({m: mm.get_mapping(m)}) self.data.append(sq) self.length.append(len(sq)) @@ -175,9 +175,41 @@ class MeaningDataset(Dataset): output["token_type_ids"] = torch.zeros(data.shape) return output + def get_token_batch(self, index_list): # must equal sequence length + return [self.data[i] for i in index_list] + + def print_token_batch(self, index_list): # must equal sequence length + data = [self.data[i] for i in index_list] + output = {} + data = torch.tensor(np.stack(data, axis=0)).long() + output["input_ids"] = data + output["labels"] = data.clone() + output["token_type_ids"] = torch.zeros(data.shape) + return output + def get_mapping_batch(self, index_list): return [self.mapping[i] for i in index_list] + def __get_mapping_str__(map, prefix): + if isinstance(map, dict): + base = "" + for key, value in map.items(): + base += prefix + str(key) + "\n" + base += MeaningDataset.__get_mapping_str__(value, prefix + " ") + return base + else: + return "" + + def print_mapping_batch(self, index_list): + tokens = self.get_token_batch(index_list) + map = self.get_mapping_batch(index_list) + s = "--------------------------------------------------------\n" + for i, m in enumerate(map): + s += str(tokens[i]) + "\n" + s += MeaningDataset.__get_mapping_str__(m, "") + s += "--------------------------------------------------------\n" + return s + def split(self, ratio): l = len(self.data) middle = int(l * ratio) @@ -238,22 +270,7 @@ class BatchGroupMeaningDataloader(Dataset): return self.dataset.get_mapping_batch(self.indexBatch[idx]) def print_mapping(self, idx): - map = self.mapping(idx) - s = "" - for m in map: - s += BatchGroupMeaningDataloader.get_mapping_str(m, "") - s += "--------\n" - return s - - def get_mapping_str(map, prefix): - if isinstance(map, dict): - base = "" - for key, value in map.items(): - base += prefix + str(key) + "\n" - base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ") - return base - else: - return "" + return self.dataset.print_mapping_batch(self.indexBatch[idx]) if __name__ == "__main__":