Refine get_nodetree to support stride and tree tokens.

This commit is contained in:
Colin 2025-08-18 11:17:35 +08:00
parent 3e6ff2d580
commit 085bd92fb9
4 changed files with 1064 additions and 1048 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -43,6 +43,7 @@ class MeaningMap:
vocab_of_tree = vocab_size - self.special_vocab
assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
self.normal_vocab = vocab_size - self.reserve_vocab
self.vocab_size = vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
@ -223,19 +224,29 @@ class MeaningMap:
)
def get_nodetree(self, meaning): # return meaning all sub items
def get_tree_node(ms_map, meaning, nvs, parent, seqlist):
ms = ms_map[meaning]
def get_tree_node(self, meaning, seq, vs, parent, seqlist, index):
ms = self.ms_map[meaning]
for m in ms[ms >= 0].tolist():
if m >= nvs:
if m >= vs:
pn = NodeTree(str(m), parent)
get_tree_node(ms_map, m, nvs, pn, seqlist)
index = get_tree_node(self, m, seq, vs, pn, seqlist, index)
else:
pn = NodeTree("<" + str(m) + ">", parent)
pn = NodeTree("<" + str(index) + "> " + str(m), parent)
index = index + 1
seqlist.append(pn)
while len(seq) > index and seq[index] >= self.normal_vocab:
pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent)
index = index + 1
seqlist.append(pn)
return index
root = NodeTree(str(meaning))
seqlist = []
get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
start = self.ms_start[meaning]
seq = self.ms_data[start : start + self.ms_len[meaning]]
get_tree_node(self, meaning, seq, self.vocab_size, root, seqlist, 0)
root.seq_node = seqlist
return root
@ -245,7 +256,7 @@ class MeaningMap:
def level_change(ms_map, meaning, current_to_common, common_to_current):
ms = ms_map[meaning]
for m in ms[ms >= 0].tolist():
if m >= self.normal_vocab:
if m >= self.vocab_size:
common_to_current[-1] = common_to_current[-1] + 1
level_change(ms_map, m, current_to_common, common_to_current)
else:

View File

@ -60,13 +60,13 @@ def get_inference(dataset, seq):
next_token = sorted_indices.detach().cpu().numpy()[0][0]
if item[i] != next_token:
node.set_seq_prop(i, "ERR_" + str(next_token))
print(str(item[i]) + " " + str(next_token) + " ERROR")
print("index: " + str(i) + " golden: " + str(item[i]) + " -> " + str(next_token) + " ERR")
node.print()
if __name__ == "__main__":
log_path = "log/bigger/version_1/"
log_path = "log/bigger/version_2/"
file = get_latest_file_safe(log_path + "/checkpoints")
checkpoint_path = log_path + "checkpoints/" + file
@ -89,13 +89,13 @@ if __name__ == "__main__":
# seq:849
# seq:991
# seq:995
meaning = 995
get_inference(md, meaning)
meaning = 991
node = map.get_nodetree(meaning)
node.print()
get_inference(md, meaning)
def DumpQK(query, key, causal_mask, index):
global relation_distance
size = query.shape[2]