diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py index d46bbf9..1d3fcc1 100644 --- a/wit/meaning/meaning_dataset.py +++ b/wit/meaning/meaning_dataset.py @@ -224,18 +224,18 @@ class MeaningMap: ) def get_nodetree(self, meaning): # return meaning all sub items - def get_tree_node(self, meaning, seq, vs, parent, seqlist, index): + def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index): ms = self.ms_map[meaning] for m in ms[ms >= 0].tolist(): if m >= vs: pn = NodeTree(str(m), parent) - index = get_tree_node(self, m, seq, vs, pn, seqlist, index) + index = get_tree_node(self, m, seq, level, vs, pn, seqlist, index) else: pn = NodeTree("<" + str(index) + "> " + str(m), parent) index = index + 1 seqlist.append(pn) - while len(seq) > index and seq[index] >= self.normal_vocab: + while len(seq) > index and level[index] >= 255: pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent) index = index + 1 seqlist.append(pn) @@ -245,8 +245,9 @@ class MeaningMap: seqlist = [] start = self.ms_start[meaning] seq = self.ms_data[start : start + self.ms_len[meaning]] + level = self.ms_level[start : start + self.ms_len[meaning]] - get_tree_node(self, meaning, seq, self.normal_vocab, root, seqlist, 0) + get_tree_node(self, meaning, seq, level, self.normal_vocab, root, seqlist, 0) root.seq_node = seqlist return root @@ -292,7 +293,7 @@ class MeaningMap: relation[i, start_index : i + 1] = 1.0 return relation - # 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面每个token的距离 + # 根据meaning的层级结构范围一个二维的数组,表示所有token跟前面每个token的距离 def get_relation_distance(self, meaning): current_to_common, common_to_current = self.get_level_change(meaning) width = len(current_to_common)