From 1642a91d803d34625af322701157e40860109204 Mon Sep 17 00:00:00 2001
From: Colin
Date: Wed, 3 Apr 2024 11:24:00 +0800
Subject: [PATCH] Add meaning map print.

---
 wit/meaning_dataset.py | 24 ++++++++++++++++++++++--
 wit/train.py           |  1 -
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py
index 0fbf344..e18b8f2 100644
--- a/wit/meaning_dataset.py
+++ b/wit/meaning_dataset.py
@@ -237,10 +237,28 @@ class BatchGroupMeaningDataloader(Dataset):
     def mapping(self, idx):
         return self.dataset.get_mapping_batch(self.indexBatch[idx])
 
+    def print_mapping(self, idx):
+        map = self.mapping(idx)
+        s = ""
+        for m in map:
+            s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
+            s += "--------\n"
+        return s
+
+    def get_mapping_str(map, prefix):
+        if isinstance(map, dict):
+            base = ""
+            for key, value in map.items():
+                base += prefix + str(key) + "\n"
+                base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
+            return base
+        else:
+            return ""
+
 
 if __name__ == "__main__":
 
-    md = MeaningDataset(1024, 115200, vocab_size=1024, size=1024)
+    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
     train, val = md.split(0.95)
     dl = BatchGroupMeaningDataloader(train, 2)
 
@@ -250,7 +268,9 @@ if __name__ == "__main__":
     ne2 = next(it)
     ne3 = next(it)
 
-    map = dl.mapping(0)
+    map1 = dl.mapping(0)
+    map2 = dl.mapping(1)
+    print(dl.print_mapping(0))
 
     dl = DataLoader(
         train,
diff --git a/wit/train.py b/wit/train.py
index a59b378..6baec8a 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -65,7 +65,6 @@ if __name__ == "__main__":
     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(
         accelerator="cuda",
-        devices=[0, 1],
         precision=precision,
         logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False),
         strategy=strategy,