refine meaning dataset tree print.

This commit is contained in:
Colin 2025-02-22 01:53:58 +08:00
parent 81f9e54ca3
commit f0469b351c
1 changed files with 17 additions and 23 deletions

View File

@ -6,6 +6,7 @@ from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
from anytree import Node, RenderTree
# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
@ -191,27 +192,20 @@ class MeaningMap:
def max_length(self):
return max(self.ms_len)
def get_tree_str(tree, prefix):
if isinstance(tree, list):
base = ""
for t in tree:
base += MeaningMap.get_tree_str(t, "")
return base
else:
if isinstance(tree, dict):
base = ""
last_is_dict = None
def print_tree(tree):
def get_tree_node(tree, parent=None):
for key, value in tree.items():
new_prefix = (len(str(key)) + 2) * " " + prefix
dict_string = MeaningMap.get_tree_str(value, new_prefix)
if dict_string:
base += "\n" + prefix + str(key) + ": " + dict_string
last_is_dict = True
else:
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
last_is_dict = False
return base
return None
pn = Node(str(key), parent)
if isinstance(value, dict):
get_tree_node(value, pn)
assert isinstance(tree, dict)
root = Node("root")
get_tree_node(tree, root)
treestr = ""
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
treestr += f"{pre}{node.name}\n"
return treestr
def get_tree_indexed_str(tree, data, prefix):
if isinstance(tree, list):
@ -366,7 +360,7 @@ class MeaningDataset(Dataset):
tokens = self.seq[idx]
tree = self.get_tree(idx)
s = str(tokens) + "\n"
s += MeaningMap.get_tree_str(tree, "")
s += MeaningMap.print_tree(tree)
return s
def copy(self, start, end):
@ -572,7 +566,7 @@ if __name__ == "__main__":
ne1 = next(it)
tree = ne1["tree"]
mask = ne1["mask"].cpu().numpy()
t = MeaningMap.get_tree_str(tree, "")
t = MeaningMap.print_tree(tree)
print(t)
m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
print(m)