Fix NodeTree construction in the meaning dataset's get_nodetree when a stride is used.

This commit is contained in:
Colin 2025-08-21 22:23:28 +08:00
parent e18ee0c781
commit bd5379f24e
1 changed file with 6 additions and 5 deletions

View File

@ -224,18 +224,18 @@ class MeaningMap:
)
def get_nodetree(self, meaning): # return meaning all sub items
def get_tree_node(self, meaning, seq, vs, parent, seqlist, index):
def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
ms = self.ms_map[meaning]
for m in ms[ms >= 0].tolist():
if m >= vs:
pn = NodeTree(str(m), parent)
index = get_tree_node(self, m, seq, vs, pn, seqlist, index)
index = get_tree_node(self, m, seq, level, vs, pn, seqlist, index)
else:
pn = NodeTree("<" + str(index) + "> " + str(m), parent)
index = index + 1
seqlist.append(pn)
while len(seq) > index and seq[index] >= self.normal_vocab:
while len(seq) > index and level[index] >= 255:
pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent)
index = index + 1
seqlist.append(pn)
@ -245,8 +245,9 @@ class MeaningMap:
seqlist = []
start = self.ms_start[meaning]
seq = self.ms_data[start : start + self.ms_len[meaning]]
level = self.ms_level[start : start + self.ms_len[meaning]]
get_tree_node(self, meaning, seq, self.normal_vocab, root, seqlist, 0)
get_tree_node(self, meaning, seq, level, self.normal_vocab, root, seqlist, 0)
root.seq_node = seqlist
return root
@ -292,7 +293,7 @@ class MeaningMap:
relation[i, start_index : i + 1] = 1.0
return relation
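# Based on the meaning's hierarchical structure, return a 2-D array giving the distance between every token and each preceding token
# Based on the meaning's hierarchical structure, return a 2-D array giving the distance between every token and each preceding token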
def get_relation_distance(self, meaning):
current_to_common, common_to_current = self.get_level_change(meaning)
width = len(current_to_common)