refine meaning dataset tree print.
This commit is contained in:
parent
81f9e54ca3
commit
f0469b351c
|
@ -6,6 +6,7 @@ from typing import Dict, Tuple
|
|||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
from anytree import Node, RenderTree
|
||||
|
||||
# import warnings
|
||||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
||||
|
@ -191,27 +192,20 @@ class MeaningMap:
|
|||
def max_length(self):
|
||||
return max(self.ms_len)
|
||||
|
||||
def get_tree_str(tree, prefix):
|
||||
if isinstance(tree, list):
|
||||
base = ""
|
||||
for t in tree:
|
||||
base += MeaningMap.get_tree_str(t, "")
|
||||
return base
|
||||
else:
|
||||
if isinstance(tree, dict):
|
||||
base = ""
|
||||
last_is_dict = None
|
||||
def print_tree(tree):
|
||||
def get_tree_node(tree, parent=None):
|
||||
for key, value in tree.items():
|
||||
new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||
dict_string = MeaningMap.get_tree_str(value, new_prefix)
|
||||
if dict_string:
|
||||
base += "\n" + prefix + str(key) + ": " + dict_string
|
||||
last_is_dict = True
|
||||
else:
|
||||
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
|
||||
last_is_dict = False
|
||||
return base
|
||||
return None
|
||||
pn = Node(str(key), parent)
|
||||
if isinstance(value, dict):
|
||||
get_tree_node(value, pn)
|
||||
|
||||
assert isinstance(tree, dict)
|
||||
root = Node("root")
|
||||
get_tree_node(tree, root)
|
||||
treestr = ""
|
||||
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
|
||||
treestr += f"{pre}{node.name}\n"
|
||||
return treestr
|
||||
|
||||
def get_tree_indexed_str(tree, data, prefix):
|
||||
if isinstance(tree, list):
|
||||
|
@ -366,7 +360,7 @@ class MeaningDataset(Dataset):
|
|||
tokens = self.seq[idx]
|
||||
tree = self.get_tree(idx)
|
||||
s = str(tokens) + "\n"
|
||||
s += MeaningMap.get_tree_str(tree, "")
|
||||
s += MeaningMap.print_tree(tree)
|
||||
return s
|
||||
|
||||
def copy(self, start, end):
|
||||
|
@ -572,7 +566,7 @@ if __name__ == "__main__":
|
|||
ne1 = next(it)
|
||||
tree = ne1["tree"]
|
||||
mask = ne1["mask"].cpu().numpy()
|
||||
t = MeaningMap.get_tree_str(tree, "")
|
||||
t = MeaningMap.print_tree(tree)
|
||||
print(t)
|
||||
m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
|
||||
print(m)
|
||||
|
|
Loading…
Reference in New Issue