Refine meaning dataset tree printing.
This commit is contained in:
parent
81f9e54ca3
commit
f0469b351c
|
@ -6,6 +6,7 @@ from typing import Dict, Tuple
|
||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
|
from anytree import Node, RenderTree
|
||||||
|
|
||||||
# import warnings
|
# import warnings
|
||||||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
||||||
|
@ -191,27 +192,20 @@ class MeaningMap:
|
||||||
def max_length(self):
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
|
||||||
def get_tree_str(tree, prefix):
|
def print_tree(tree):
|
||||||
if isinstance(tree, list):
|
def get_tree_node(tree, parent=None):
|
||||||
base = ""
|
|
||||||
for t in tree:
|
|
||||||
base += MeaningMap.get_tree_str(t, "")
|
|
||||||
return base
|
|
||||||
else:
|
|
||||||
if isinstance(tree, dict):
|
|
||||||
base = ""
|
|
||||||
last_is_dict = None
|
|
||||||
for key, value in tree.items():
|
for key, value in tree.items():
|
||||||
new_prefix = (len(str(key)) + 2) * " " + prefix
|
pn = Node(str(key), parent)
|
||||||
dict_string = MeaningMap.get_tree_str(value, new_prefix)
|
if isinstance(value, dict):
|
||||||
if dict_string:
|
get_tree_node(value, pn)
|
||||||
base += "\n" + prefix + str(key) + ": " + dict_string
|
|
||||||
last_is_dict = True
|
assert isinstance(tree, dict)
|
||||||
else:
|
root = Node("root")
|
||||||
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
|
get_tree_node(tree, root)
|
||||||
last_is_dict = False
|
treestr = ""
|
||||||
return base
|
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
|
||||||
return None
|
treestr += f"{pre}{node.name}\n"
|
||||||
|
return treestr
|
||||||
|
|
||||||
def get_tree_indexed_str(tree, data, prefix):
|
def get_tree_indexed_str(tree, data, prefix):
|
||||||
if isinstance(tree, list):
|
if isinstance(tree, list):
|
||||||
|
@ -366,7 +360,7 @@ class MeaningDataset(Dataset):
|
||||||
tokens = self.seq[idx]
|
tokens = self.seq[idx]
|
||||||
tree = self.get_tree(idx)
|
tree = self.get_tree(idx)
|
||||||
s = str(tokens) + "\n"
|
s = str(tokens) + "\n"
|
||||||
s += MeaningMap.get_tree_str(tree, "")
|
s += MeaningMap.print_tree(tree)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def copy(self, start, end):
|
def copy(self, start, end):
|
||||||
|
@ -572,7 +566,7 @@ if __name__ == "__main__":
|
||||||
ne1 = next(it)
|
ne1 = next(it)
|
||||||
tree = ne1["tree"]
|
tree = ne1["tree"]
|
||||||
mask = ne1["mask"].cpu().numpy()
|
mask = ne1["mask"].cpu().numpy()
|
||||||
t = MeaningMap.get_tree_str(tree, "")
|
t = MeaningMap.print_tree(tree)
|
||||||
print(t)
|
print(t)
|
||||||
m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
|
m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
|
||||||
print(m)
|
print(m)
|
||||||
|
|
Loading…
Reference in New Issue