From a1d5fce3000987483b5e31b615a6c7403bc688b6 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 23 Feb 2025 16:46:37 +0800 Subject: [PATCH] Refine node tree define and print. --- wit/dataset/meaning_dataset.py | 3 ++- wit/dataset/node_tree.py | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py index 427d030..6c0dd18 100644 --- a/wit/dataset/meaning_dataset.py +++ b/wit/dataset/meaning_dataset.py @@ -190,7 +190,8 @@ class MeaningMap: tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m return tree - return get_tree_dict(self.ms_map, self.vocab_size, meaning) + td = get_tree_dict(self.ms_map, self.vocab_size, meaning) + return {meaning: td} def max_length(self): return max(self.ms_len) diff --git a/wit/dataset/node_tree.py b/wit/dataset/node_tree.py index 59a80b2..ae20ed7 100644 --- a/wit/dataset/node_tree.py +++ b/wit/dataset/node_tree.py @@ -1,4 +1,21 @@ -from anytree import Node, RenderTree +from anytree import RenderTree + +from anytree.node.nodemixin import NodeMixin +from anytree.node.util import _repr + + +class Node(NodeMixin): + def __init__(self, name, parent=None, children=None, **kwargs): + self.__dict__.update(kwargs) + self.name = name + self.prop = "" + self.parent = parent + if children: + self.children = children + + def __repr__(self): + args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])] + return _repr(self, args=args, nameblacklist=["name"]) class NodeTree: @@ -6,6 +23,7 @@ class NodeTree: self.tree = tree self.node = None self.seq_node = [] + self.get_node() def get_node(self): def get_tree_node(tree, parent, seqlist): @@ -25,10 +43,13 @@ class NodeTree: self.node = root.children[0] if len(self.tree) == 1 else root return self.node + def set_seq_prop(self, index, prop): + self.seq_node[index].prop = prop + def print(self): treestr = "" for pre, fill, node in RenderTree(self.get_node()): - treestr += f"{pre}{node.name}\n" + treestr += f"{pre}{node.name} {node.prop}\n" print(treestr) def print_tree(tree):