Add meaning map print.
This commit is contained in:
parent
89c12380cb
commit
1642a91d80
|
@ -237,10 +237,28 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
def mapping(self, idx):
|
def mapping(self, idx):
|
||||||
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
||||||
|
|
||||||
|
def print_mapping(self, idx):
|
||||||
|
map = self.mapping(idx)
|
||||||
|
s = ""
|
||||||
|
for m in map:
|
||||||
|
s += BatchGroupMeaningDataloader.get_mapping_str(m, "")
|
||||||
|
s += "--------\n"
|
||||||
|
return s
|
||||||
|
|
||||||
|
def get_mapping_str(map, prefix):
|
||||||
|
if isinstance(map, dict):
|
||||||
|
base = ""
|
||||||
|
for key, value in map.items():
|
||||||
|
base += prefix + str(key) + "\n"
|
||||||
|
base += BatchGroupMeaningDataloader.get_mapping_str(value, prefix + " ")
|
||||||
|
return base
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
md = MeaningDataset(1024, 115200, vocab_size=1024, size=1024)
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
|
||||||
train, val = md.split(0.95)
|
train, val = md.split(0.95)
|
||||||
|
|
||||||
dl = BatchGroupMeaningDataloader(train, 2)
|
dl = BatchGroupMeaningDataloader(train, 2)
|
||||||
|
@ -250,7 +268,9 @@ if __name__ == "__main__":
|
||||||
ne2 = next(it)
|
ne2 = next(it)
|
||||||
ne3 = next(it)
|
ne3 = next(it)
|
||||||
|
|
||||||
map = dl.mapping(0)
|
map1 = dl.mapping(0)
|
||||||
|
map2 = dl.mapping(1)
|
||||||
|
print(dl.print_mapping(0))
|
||||||
|
|
||||||
dl = DataLoader(
|
dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
|
|
|
@ -65,7 +65,6 @@ if __name__ == "__main__":
|
||||||
torch.set_float32_matmul_precision("medium")
|
torch.set_float32_matmul_precision("medium")
|
||||||
lit_trainer = pl.Trainer(
|
lit_trainer = pl.Trainer(
|
||||||
accelerator="cuda",
|
accelerator="cuda",
|
||||||
devices=[0, 1],
|
|
||||||
precision=precision,
|
precision=precision,
|
||||||
logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False),
|
logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False),
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
|
|
Loading…
Reference in New Issue