Fix meaning dataset get NodeTree with stride.
This commit is contained in:
parent
e18ee0c781
commit
bd5379f24e
|
@ -224,18 +224,18 @@ class MeaningMap:
|
|||
)
|
||||
|
||||
def get_nodetree(self, meaning): # return meaning all sub items
|
||||
def get_tree_node(self, meaning, seq, vs, parent, seqlist, index):
|
||||
def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
|
||||
ms = self.ms_map[meaning]
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= vs:
|
||||
pn = NodeTree(str(m), parent)
|
||||
index = get_tree_node(self, m, seq, vs, pn, seqlist, index)
|
||||
index = get_tree_node(self, m, seq, level, vs, pn, seqlist, index)
|
||||
else:
|
||||
pn = NodeTree("<" + str(index) + "> " + str(m), parent)
|
||||
index = index + 1
|
||||
seqlist.append(pn)
|
||||
|
||||
while len(seq) > index and seq[index] >= self.normal_vocab:
|
||||
while len(seq) > index and level[index] >= 255:
|
||||
pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent)
|
||||
index = index + 1
|
||||
seqlist.append(pn)
|
||||
|
@ -245,8 +245,9 @@ class MeaningMap:
|
|||
seqlist = []
|
||||
start = self.ms_start[meaning]
|
||||
seq = self.ms_data[start : start + self.ms_len[meaning]]
|
||||
level = self.ms_level[start : start + self.ms_len[meaning]]
|
||||
|
||||
get_tree_node(self, meaning, seq, self.normal_vocab, root, seqlist, 0)
|
||||
get_tree_node(self, meaning, seq, level, self.normal_vocab, root, seqlist, 0)
|
||||
root.seq_node = seqlist
|
||||
return root
|
||||
|
||||
|
@ -292,7 +293,7 @@ class MeaningMap:
|
|||
relation[i, start_index : i + 1] = 1.0
|
||||
return relation
|
||||
|
||||
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面每个token的距离
|
||||
# 根据meaning的层级结构范围一个二维的数组,表示所有token跟前面每个token的距离
|
||||
def get_relation_distance(self, meaning):
|
||||
current_to_common, common_to_current = self.get_level_change(meaning)
|
||||
width = len(current_to_common)
|
||||
|
|
Loading…
Reference in New Issue