From f0469b351c63c6167dcc464f054ac86ab0490ccf Mon Sep 17 00:00:00 2001
From: Colin
Date: Sat, 22 Feb 2025 01:53:58 +0800
Subject: [PATCH] refine meaning dataset tree print.

---
Reviewer notes (free-form text, ignored by git am):
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; check the context lines against the tree before applying.
- print_tree was revised in review (staticmethod, docstring, list input,
  TypeError instead of assert), so the commit and blob hashes in this file
  still identify the original, unrevised change.

 wit/dataset/meaning_dataset.py | 62 +++++++++++++++++++++-------------
 1 file changed, 39 insertions(+), 23 deletions(-)

diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 56c5400..8d74fe2 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -6,6 +6,7 @@ from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
+from anytree import Node, RenderTree
 
 # import warnings
 # warnings.filterwarnings("ignore", ".*does not have many workers.*")
@@ -191,27 +192,42 @@ class MeaningMap:
     def max_length(self):
         return max(self.ms_len)
 
-    def get_tree_str(tree, prefix):
-        if isinstance(tree, list):
-            base = ""
-            for t in tree:
-                base += MeaningMap.get_tree_str(t, "")
-            return base
-        else:
-            if isinstance(tree, dict):
-                base = ""
-                last_is_dict = None
-                for key, value in tree.items():
-                    new_prefix = (len(str(key)) + 2) * " " + prefix
-                    dict_string = MeaningMap.get_tree_str(value, new_prefix)
-                    if dict_string:
-                        base += "\n" + prefix + str(key) + ": " + dict_string
-                        last_is_dict = True
-                    else:
-                        base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
-                        last_is_dict = False
-                return base
-            return None
+    @staticmethod
+    def print_tree(tree):
+        """Render a nested meaning tree as an indented text diagram.
+
+        Despite the name, nothing is printed; the rendered text is returned.
+
+        Args:
+            tree: dict mapping node labels to sub-trees. Dict values become
+                child nodes; any other value marks a leaf. A list of such
+                dicts is also accepted (as the replaced get_tree_str did)
+                and rendered one tree after another.
+
+        Returns:
+            One newline-terminated line per node. A dict with exactly one
+            top-level key is rendered from that key; otherwise a synthetic
+            "root" node heads the output.
+
+        Raises:
+            TypeError: if tree is neither a dict nor a list.
+        """
+        if isinstance(tree, list):
+            return "".join(MeaningMap.print_tree(t) for t in tree)
+        if not isinstance(tree, dict):
+            raise TypeError(f"tree must be a dict or list, got {type(tree).__name__}")
+
+        def get_tree_node(tree, parent=None):
+            # anytree attaches each Node to its parent on construction.
+            for key, value in tree.items():
+                pn = Node(str(key), parent)
+                if isinstance(value, dict):
+                    get_tree_node(value, pn)
+
+        root = Node("root")
+        get_tree_node(tree, root)
+        top = root.children[0] if len(tree) == 1 else root
+        return "".join(f"{pre}{node.name}\n" for pre, _, node in RenderTree(top))
 
     def get_tree_indexed_str(tree, data, prefix):
         if isinstance(tree, list):
@@ -366,7 +382,7 @@ class MeaningDataset(Dataset):
         tokens = self.seq[idx]
         tree = self.get_tree(idx)
         s = str(tokens) + "\n"
-        s += MeaningMap.get_tree_str(tree, "")
+        s += MeaningMap.print_tree(tree)
         return s
 
     def copy(self, start, end):
@@ -572,7 +588,7 @@ if __name__ == "__main__":
     ne1 = next(it)
     tree = ne1["tree"]
     mask = ne1["mask"].cpu().numpy()
-    t = MeaningMap.get_tree_str(tree, "")
+    t = MeaningMap.print_tree(tree)
     print(t)
     m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
     print(m)