Add meaning map print.
This commit is contained in:
parent
89c12380cb
commit
1642a91d80
|
@ -237,10 +237,28 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
def mapping(self, idx):
|
||||
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
||||
|
||||
def print_mapping(self, idx):
|
||||
map = self.mapping(idx)
|
||||
s = ""
|
||||
for m in map:
|
||||
s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
|
||||
s += "--------\n"
|
||||
return s
|
||||
|
||||
def get_mapping_str(map, prefix):
|
||||
if isinstance(map, dict):
|
||||
base = ""
|
||||
for key, value in map.items():
|
||||
base += prefix + str(key) + "\n"
|
||||
base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
|
||||
return base
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
md = MeaningDataset(1024, 115200, vocab_size=1024, size=1024)
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
|
||||
train, val = md.split(0.95)
|
||||
|
||||
dl = BatchGroupMeaningDataloader(train, 2)
|
||||
|
@ -250,7 +268,9 @@ if __name__ == "__main__":
|
|||
ne2 = next(it)
|
||||
ne3 = next(it)
|
||||
|
||||
map = dl.mapping(0)
|
||||
map1 = dl.mapping(0)
|
||||
map2 = dl.mapping(1)
|
||||
print(dl.print_mapping(0))
|
||||
|
||||
dl = DataLoader(
|
||||
train,
|
||||
|
|
|
@ -65,7 +65,6 @@ if __name__ == "__main__":
|
|||
torch.set_float32_matmul_precision("medium")
|
||||
lit_trainer = pl.Trainer(
|
||||
accelerator="cuda",
|
||||
devices=[0, 1],
|
||||
precision=precision,
|
||||
logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False),
|
||||
strategy=strategy,
|
||||
|
|
Loading…
Reference in New Issue