# Witllm/wit/dataset/node_tree.py
#
# 84 lines
# 3.0 KiB
# Python

from anytree import Node, RenderTree
class NodeTree:
    """Wrap a nested dict and expose it as an ``anytree`` node hierarchy.

    Attributes are plain instance fields:
      * ``tree``     -- the nested dict handed to the constructor.
      * ``node``     -- cached result of :meth:`get_node` (``None`` until built).
      * ``seq_node`` -- leaf nodes, appended in the order they are created
                        while the dict is walked.
    """

    def __init__(self, tree: dict):
        self.tree = tree
        self.node = None
        self.seq_node = []

    def _grow(self, subtree, parent):
        """Create one child of ``parent`` per entry of ``subtree``.

        A dict value becomes an inner node named ``str(key)`` and is walked
        recursively. Any other value becomes a leaf named ``"<key>"`` and is
        recorded in ``self.seq_node``; the value itself is not stored.
        """
        for key, value in subtree.items():
            if not isinstance(value, dict):
                leaf = Node(f"<{key!s}>", parent)
                self.seq_node.append(leaf)
                continue
            branch = Node(str(key), parent)
            self._grow(value, branch)

    def get_node(self):
        """Return the top node of the hierarchy, building it on first use.

        The hierarchy is built under a synthetic ``"root"`` node. When the
        dict has exactly one top-level key, that key's node is returned
        instead of the synthetic root. The result is cached in ``self.node``
        and reused by later calls.
        """
        cached = self.node
        if cached:
            return cached

        assert isinstance(self.tree, dict)
        root = Node("root")
        self._grow(self.tree, root)

        if len(self.tree) == 1:
            # A single top-level entry makes the synthetic root redundant.
            self.node = root.children[0]
        else:
            self.node = root
        return self.node

    def print(self):
        """Write the rendered hierarchy to stdout, one node name per line."""
        rendered = [
            f"{pre}{node.name}\n"
            for pre, _fill, node in RenderTree(self.get_node())
        ]
        # The method name does not shadow the builtin inside this body.
        print("".join(rendered))
def print_tree(tree):
    """Render a nested dict as an indented text tree and return the string.

    Despite the name, nothing is printed. Every key becomes a node named
    ``str(key)``; only dict values are descended into, other values are
    ignored. The hierarchy is built under a synthetic ``"root"`` node, which
    is left out of the output when ``tree`` has exactly one top-level key.
    Each rendered line is the ``RenderTree`` prefix followed by the node name
    and a newline.
    """
    # NOTE(review): takes no ``self`` and is treated as a module-level helper
    # here -- confirm it is not meant to be a static method of NodeTree.

    def attach(subtree, parent):
        # One child per key, recursing only where the value is itself a dict.
        for key, value in subtree.items():
            child = Node(str(key), parent)
            if isinstance(value, dict):
                attach(value, child)

    assert isinstance(tree, dict)
    root = Node("root")
    attach(tree, root)

    if len(tree) == 1:
        start = root.children[0]
    else:
        start = root
    return "".join(
        f"{pre}{node.name}\n" for pre, _fill, node in RenderTree(start)
    )
# def get_tree_indexed_str(tree, data, prefix):
# if isinstance(tree, list):
# base = ""
# qlen = 0
# for i, t in enumerate(tree):
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
# base += s
# qlen += l
# return (base, qlen)
# else:
# if isinstance(tree, dict):
# base = ""
# qlen = 0
# last_is_dict = None
# for key, value in tree.items():
# new_prefix = (len(str(key)) + 2) * " " + prefix
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
# if dict_string:
# base += "\n" + prefix + str(key) + ": " + dict_string
# last_is_dict = True
# else:
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
# last_is_dict = False
# qlen += l
# return (base, qlen)
# return (None, 1)
# def token_frequency(tree, freq):
# if isinstance(tree, dict):
# for key, value in tree.items():
# if key in freq:
# freq[key] = freq[key] + 1
# else:
# freq[key] = 1
# MeaningMap.token_frequency(value, freq)