Add meaning map print.

This commit is contained in:
Colin 2024-04-03 11:24:00 +08:00
parent 89c12380cb
commit 1642a91d80
2 changed files with 22 additions and 3 deletions

View File

@ -237,10 +237,28 @@ class BatchGroupMeaningDataloader(Dataset):
def mapping(self, idx): def mapping(self, idx):
return self.dataset.get_mapping_batch(self.indexBatch[idx]) return self.dataset.get_mapping_batch(self.indexBatch[idx])
def print_mapping(self, idx):
map = self.mapping(idx)
s = ""
for m in map:
s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
s += "--------\n"
return s
def get_mapping_str(map, prefix):
if isinstance(map, dict):
base = ""
for key, value in map.items():
base += prefix + str(key) + "\n"
base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
return base
else:
return ""
if __name__ == "__main__": if __name__ == "__main__":
md = MeaningDataset(1024, 115200, vocab_size=1024, size=1024) md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
train, val = md.split(0.95) train, val = md.split(0.95)
dl = BatchGroupMeaningDataloader(train, 2) dl = BatchGroupMeaningDataloader(train, 2)
@ -250,7 +268,9 @@ if __name__ == "__main__":
ne2 = next(it) ne2 = next(it)
ne3 = next(it) ne3 = next(it)
map = dl.mapping(0) map1 = dl.mapping(0)
map2 = dl.mapping(1)
print(dl.print_mapping(0))
dl = DataLoader( dl = DataLoader(
train, train,

View File

@ -65,7 +65,6 @@ if __name__ == "__main__":
torch.set_float32_matmul_precision("medium") torch.set_float32_matmul_precision("medium")
lit_trainer = pl.Trainer( lit_trainer = pl.Trainer(
accelerator="cuda", accelerator="cuda",
devices=[0, 1],
precision=precision, precision=precision,
logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False), logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False),
strategy=strategy, strategy=strategy,