Refine get_nodetree to support stride and tree tokens.
This commit is contained in:
parent
3e6ff2d580
commit
085bd92fb9
1036
wit/Untitled-1.ini
1036
wit/Untitled-1.ini
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@@ -43,6 +43,7 @@ class MeaningMap:
|
|||
vocab_of_tree = vocab_size - self.special_vocab
|
||||
assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
|
||||
self.normal_vocab = vocab_size - self.reserve_vocab
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.max_subitem = max_subitem
|
||||
self.min_subitem = min_subitem
|
||||
|
@@ -223,19 +224,29 @@ class MeaningMap:
|
|||
)
|
||||
|
||||
def get_nodetree(self, meaning): # return meaning all sub items
|
||||
def get_tree_node(ms_map, meaning, nvs, parent, seqlist):
|
||||
ms = ms_map[meaning]
|
||||
def get_tree_node(self, meaning, seq, vs, parent, seqlist, index):
|
||||
ms = self.ms_map[meaning]
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= nvs:
|
||||
if m >= vs:
|
||||
pn = NodeTree(str(m), parent)
|
||||
get_tree_node(ms_map, m, nvs, pn, seqlist)
|
||||
index = get_tree_node(self, m, seq, vs, pn, seqlist, index)
|
||||
else:
|
||||
pn = NodeTree("<" + str(m) + ">", parent)
|
||||
pn = NodeTree("<" + str(index) + "> " + str(m), parent)
|
||||
index = index + 1
|
||||
seqlist.append(pn)
|
||||
|
||||
while len(seq) > index and seq[index] >= self.normal_vocab:
|
||||
pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent)
|
||||
index = index + 1
|
||||
seqlist.append(pn)
|
||||
return index
|
||||
|
||||
root = NodeTree(str(meaning))
|
||||
seqlist = []
|
||||
get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
|
||||
start = self.ms_start[meaning]
|
||||
seq = self.ms_data[start : start + self.ms_len[meaning]]
|
||||
|
||||
get_tree_node(self, meaning, seq, self.vocab_size, root, seqlist, 0)
|
||||
root.seq_node = seqlist
|
||||
return root
|
||||
|
||||
|
@@ -245,7 +256,7 @@ class MeaningMap:
|
|||
def level_change(ms_map, meaning, current_to_common, common_to_current):
|
||||
ms = ms_map[meaning]
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= self.normal_vocab:
|
||||
if m >= self.vocab_size:
|
||||
common_to_current[-1] = common_to_current[-1] + 1
|
||||
level_change(ms_map, m, current_to_common, common_to_current)
|
||||
else:
|
||||
|
|
|
@@ -60,13 +60,13 @@ def get_inference(dataset, seq):
|
|||
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
||||
if item[i] != next_token:
|
||||
node.set_seq_prop(i, "ERR_" + str(next_token))
|
||||
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
||||
print("index: " + str(i) + " golden: " + str(item[i]) + " -> " + str(next_token) + " ERR")
|
||||
node.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
log_path = "log/bigger/version_1/"
|
||||
log_path = "log/bigger/version_2/"
|
||||
|
||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||
checkpoint_path = log_path + "checkpoints/" + file
|
||||
|
@@ -89,13 +89,13 @@ if __name__ == "__main__":
|
|||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
|
||||
get_inference(md, meaning)
|
||||
meaning = 991
|
||||
|
||||
node = map.get_nodetree(meaning)
|
||||
node.print()
|
||||
|
||||
get_inference(md, meaning)
|
||||
|
||||
def DumpQK(query, key, causal_mask, index):
|
||||
global relation_distance
|
||||
size = query.shape[2]
|
||||
|
|
Loading…
Reference in New Issue