Refine get_nodetree to support stride and tree tokens.

This commit is contained in:
Colin 2025-08-18 11:17:35 +08:00
parent 3e6ff2d580
commit 085bd92fb9
4 changed files with 1064 additions and 1048 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -43,6 +43,7 @@ class MeaningMap:
vocab_of_tree = vocab_size - self.special_vocab vocab_of_tree = vocab_size - self.special_vocab
assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special" assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
self.normal_vocab = vocab_size - self.reserve_vocab self.normal_vocab = vocab_size - self.reserve_vocab
self.vocab_size = vocab_size
self.max_subitem = max_subitem self.max_subitem = max_subitem
self.min_subitem = min_subitem self.min_subitem = min_subitem
@ -223,19 +224,29 @@ class MeaningMap:
) )
def get_nodetree(self, meaning): # return meaning all sub items def get_nodetree(self, meaning): # return meaning all sub items
def get_tree_node(ms_map, meaning, nvs, parent, seqlist): def get_tree_node(self, meaning, seq, vs, parent, seqlist, index):
ms = ms_map[meaning] ms = self.ms_map[meaning]
for m in ms[ms >= 0].tolist(): for m in ms[ms >= 0].tolist():
if m >= nvs: if m >= vs:
pn = NodeTree(str(m), parent) pn = NodeTree(str(m), parent)
get_tree_node(ms_map, m, nvs, pn, seqlist) index = get_tree_node(self, m, seq, vs, pn, seqlist, index)
else: else:
pn = NodeTree("<" + str(m) + ">", parent) pn = NodeTree("<" + str(index) + "> " + str(m), parent)
index = index + 1
seqlist.append(pn) seqlist.append(pn)
while len(seq) > index and seq[index] >= self.normal_vocab:
pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent)
index = index + 1
seqlist.append(pn)
return index
root = NodeTree(str(meaning)) root = NodeTree(str(meaning))
seqlist = [] seqlist = []
get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist) start = self.ms_start[meaning]
seq = self.ms_data[start : start + self.ms_len[meaning]]
get_tree_node(self, meaning, seq, self.vocab_size, root, seqlist, 0)
root.seq_node = seqlist root.seq_node = seqlist
return root return root
@ -245,7 +256,7 @@ class MeaningMap:
def level_change(ms_map, meaning, current_to_common, common_to_current): def level_change(ms_map, meaning, current_to_common, common_to_current):
ms = ms_map[meaning] ms = ms_map[meaning]
for m in ms[ms >= 0].tolist(): for m in ms[ms >= 0].tolist():
if m >= self.normal_vocab: if m >= self.vocab_size:
common_to_current[-1] = common_to_current[-1] + 1 common_to_current[-1] = common_to_current[-1] + 1
level_change(ms_map, m, current_to_common, common_to_current) level_change(ms_map, m, current_to_common, common_to_current)
else: else:

View File

@ -60,13 +60,13 @@ def get_inference(dataset, seq):
next_token = sorted_indices.detach().cpu().numpy()[0][0] next_token = sorted_indices.detach().cpu().numpy()[0][0]
if item[i] != next_token: if item[i] != next_token:
node.set_seq_prop(i, "ERR_" + str(next_token)) node.set_seq_prop(i, "ERR_" + str(next_token))
print(str(item[i]) + " " + str(next_token) + " ERROR") print("index: " + str(i) + " golden: " + str(item[i]) + " -> " + str(next_token) + " ERR")
node.print() node.print()
if __name__ == "__main__": if __name__ == "__main__":
log_path = "log/bigger/version_1/" log_path = "log/bigger/version_2/"
file = get_latest_file_safe(log_path + "/checkpoints") file = get_latest_file_safe(log_path + "/checkpoints")
checkpoint_path = log_path + "checkpoints/" + file checkpoint_path = log_path + "checkpoints/" + file
@ -89,13 +89,13 @@ if __name__ == "__main__":
# seq:849 # seq:849
# seq:991 # seq:991
# seq:995 # seq:995
meaning = 995 meaning = 991
get_inference(md, meaning)
node = map.get_nodetree(meaning) node = map.get_nodetree(meaning)
node.print() node.print()
get_inference(md, meaning)
def DumpQK(query, key, causal_mask, index): def DumpQK(query, key, causal_mask, index):
global relation_distance global relation_distance
size = query.shape[2] size = query.shape[2]