Witllm/wit/dataset/node_tree.py

105 lines
3.6 KiB
Python

from anytree import RenderTree
from anytree.node.nodemixin import NodeMixin
from anytree.node.util import _repr
class Node(NodeMixin):
    """Tree node carrying a ``name`` and a free-form ``prop`` annotation.

    Extra keyword arguments are copied onto the instance as attributes.
    ``prop`` always starts as an empty string: it is assigned after the
    keyword arguments are applied, so a ``prop=`` keyword is overwritten.
    """

    def __init__(self, name, parent=None, children=None, **kwargs):
        """Create the node and link it into a tree.

        Args:
            name: Label stored on the node.
            parent: Node to attach to, or ``None`` to leave it detached.
            children: Optional child nodes; skipped when falsy.
            **kwargs: Arbitrary attributes copied onto the instance.
        """
        # kwargs go in first so the explicit assignments below win on clashes.
        vars(self).update(kwargs)
        self.name, self.prop = name, ""
        self.parent = parent
        if children:
            self.children = children

    def __repr__(self):
        """Return a repr whose positional argument is the separator-joined path."""
        names = [str(step.name) for step in self.path]
        # The leading empty element makes the joined path start with the separator.
        path_text = self.separator.join(["", *names])
        return _repr(self, args=[repr(path_text)], nameblacklist=["name"])
class NodeTree:
    """Node-based view of a nested dict that also tracks its leaves in order.

    Dict values become inner nodes named after their key; every other value
    becomes a leaf named ``<key>``. Leaves are collected in ``seq_node`` in
    traversal order so they can be addressed by position.
    """

    def __init__(self, tree: dict):
        """Store ``tree`` and build the node structure immediately.

        Args:
            tree: Nested mapping to convert.
        """
        self.tree = tree
        self.node = None
        self.seq_node = []
        self.get_node()

    def get_node(self):
        """Return the top node, building the tree on the first call.

        When ``tree`` has exactly one top-level key, that key's node is
        returned; otherwise a synthetic node named ``root`` is returned.

        Raises:
            AssertionError: If ``self.tree`` is not a dict.
        """
        cached = self.node
        if cached:
            return cached
        assert isinstance(self.tree, dict)

        def attach(mapping, parent, leaves):
            # Depth-first walk: dict values recurse, everything else is a leaf.
            for key, value in mapping.items():
                if not isinstance(value, dict):
                    leaves.append(Node(f"<{key!s}>", parent))
                    continue
                attach(value, Node(str(key), parent), leaves)

        root = Node("root")
        attach(self.tree, root, self.seq_node)
        if len(self.tree) == 1:
            self.node = root.children[0]
        else:
            self.node = root
        return self.node

    def set_seq_prop(self, index, prop):
        """Set ``prop`` on the leaf at position ``index`` of ``seq_node``."""
        leaf = self.seq_node[index]
        leaf.prop = prop

    def print(self):
        """Print the rendered tree, one ``<prefix><name> <prop>`` line per node.

        The text already ends with a newline, so the output is followed by a
        blank line.
        """
        rows = [
            f"{pre}{node.name} {node.prop}\n"
            for pre, _fill, node in RenderTree(self.get_node())
        ]
        print("".join(rows))
def print_tree(tree):
    """Render the keys of a nested dict as a text tree and return the string.

    Despite the name, nothing is written to stdout. Every key becomes a node;
    only dict values are descended into. When ``tree`` has exactly one
    top-level key, rendering starts at that key's node rather than at the
    synthetic ``root`` node.

    Args:
        tree: Nested mapping to render.

    Returns:
        One line per node, each terminated by a newline.

    Raises:
        AssertionError: If ``tree`` is not a dict.
    """

    def attach(mapping, parent=None):
        for key, value in mapping.items():
            child = Node(str(key), parent)
            if isinstance(value, dict):
                attach(value, child)

    assert isinstance(tree, dict)
    root = Node("root")
    attach(tree, root)
    top = root if len(tree) != 1 else root.children[0]
    return "".join(f"{pre}{node.name}\n" for pre, _fill, node in RenderTree(top))
# def get_tree_indexed_str(tree, data, prefix):
# if isinstance(tree, list):
# base = ""
# qlen = 0
# for i, t in enumerate(tree):
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
# base += s
# qlen += l
# return (base, qlen)
# else:
# if isinstance(tree, dict):
# base = ""
# qlen = 0
# last_is_dict = None
# for key, value in tree.items():
# new_prefix = (len(str(key)) + 2) * " " + prefix
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
# if dict_string:
# base += "\n" + prefix + str(key) + ": " + dict_string
# last_is_dict = True
# else:
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
# last_is_dict = False
# qlen += l
# return (base, qlen)
# return (None, 1)
# def token_frequency(tree, freq):
# if isinstance(tree, dict):
# for key, value in tree.items():
# if key in freq:
# freq[key] = freq[key] + 1
# else:
# freq[key] = 1
# MeaningMap.token_frequency(value, freq)