Refine get_nodetree to support stride and tree tokens.
This commit is contained in:
parent
3e6ff2d580
commit
085bd92fb9
1036
wit/Untitled-1.ini
1036
wit/Untitled-1.ini
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -43,6 +43,7 @@ class MeaningMap:
|
||||||
vocab_of_tree = vocab_size - self.special_vocab
|
vocab_of_tree = vocab_size - self.special_vocab
|
||||||
assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
|
assert reserve_vocab >= self.special_vocab, "must reserve enough vocab for special"
|
||||||
self.normal_vocab = vocab_size - self.reserve_vocab
|
self.normal_vocab = vocab_size - self.reserve_vocab
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
self.min_subitem = min_subitem
|
self.min_subitem = min_subitem
|
||||||
|
@ -223,19 +224,29 @@ class MeaningMap:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_nodetree(self, meaning): # return meaning all sub items
|
def get_nodetree(self, meaning): # return meaning all sub items
|
||||||
def get_tree_node(ms_map, meaning, nvs, parent, seqlist):
|
def get_tree_node(self, meaning, seq, vs, parent, seqlist, index):
|
||||||
ms = ms_map[meaning]
|
ms = self.ms_map[meaning]
|
||||||
for m in ms[ms >= 0].tolist():
|
for m in ms[ms >= 0].tolist():
|
||||||
if m >= nvs:
|
if m >= vs:
|
||||||
pn = NodeTree(str(m), parent)
|
pn = NodeTree(str(m), parent)
|
||||||
get_tree_node(ms_map, m, nvs, pn, seqlist)
|
index = get_tree_node(self, m, seq, vs, pn, seqlist, index)
|
||||||
else:
|
else:
|
||||||
pn = NodeTree("<" + str(m) + ">", parent)
|
pn = NodeTree("<" + str(index) + "> " + str(m), parent)
|
||||||
|
index = index + 1
|
||||||
seqlist.append(pn)
|
seqlist.append(pn)
|
||||||
|
|
||||||
|
while len(seq) > index and seq[index] >= self.normal_vocab:
|
||||||
|
pn = NodeTree("<" + str(index) + "> " + str(seq[index]), parent)
|
||||||
|
index = index + 1
|
||||||
|
seqlist.append(pn)
|
||||||
|
return index
|
||||||
|
|
||||||
root = NodeTree(str(meaning))
|
root = NodeTree(str(meaning))
|
||||||
seqlist = []
|
seqlist = []
|
||||||
get_tree_node(self.ms_map, meaning, self.normal_vocab, root, seqlist)
|
start = self.ms_start[meaning]
|
||||||
|
seq = self.ms_data[start : start + self.ms_len[meaning]]
|
||||||
|
|
||||||
|
get_tree_node(self, meaning, seq, self.vocab_size, root, seqlist, 0)
|
||||||
root.seq_node = seqlist
|
root.seq_node = seqlist
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
@ -245,7 +256,7 @@ class MeaningMap:
|
||||||
def level_change(ms_map, meaning, current_to_common, common_to_current):
|
def level_change(ms_map, meaning, current_to_common, common_to_current):
|
||||||
ms = ms_map[meaning]
|
ms = ms_map[meaning]
|
||||||
for m in ms[ms >= 0].tolist():
|
for m in ms[ms >= 0].tolist():
|
||||||
if m >= self.normal_vocab:
|
if m >= self.vocab_size:
|
||||||
common_to_current[-1] = common_to_current[-1] + 1
|
common_to_current[-1] = common_to_current[-1] + 1
|
||||||
level_change(ms_map, m, current_to_common, common_to_current)
|
level_change(ms_map, m, current_to_common, common_to_current)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -60,13 +60,13 @@ def get_inference(dataset, seq):
|
||||||
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
||||||
if item[i] != next_token:
|
if item[i] != next_token:
|
||||||
node.set_seq_prop(i, "ERR_" + str(next_token))
|
node.set_seq_prop(i, "ERR_" + str(next_token))
|
||||||
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
print("index: " + str(i) + " golden: " + str(item[i]) + " -> " + str(next_token) + " ERR")
|
||||||
node.print()
|
node.print()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
log_path = "log/bigger/version_1/"
|
log_path = "log/bigger/version_2/"
|
||||||
|
|
||||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||||
checkpoint_path = log_path + "checkpoints/" + file
|
checkpoint_path = log_path + "checkpoints/" + file
|
||||||
|
@ -89,13 +89,13 @@ if __name__ == "__main__":
|
||||||
# seq:849
|
# seq:849
|
||||||
# seq:991
|
# seq:991
|
||||||
# seq:995
|
# seq:995
|
||||||
meaning = 995
|
meaning = 991
|
||||||
|
|
||||||
get_inference(md, meaning)
|
|
||||||
|
|
||||||
node = map.get_nodetree(meaning)
|
node = map.get_nodetree(meaning)
|
||||||
node.print()
|
node.print()
|
||||||
|
|
||||||
|
get_inference(md, meaning)
|
||||||
|
|
||||||
def DumpQK(query, key, causal_mask, index):
|
def DumpQK(query, key, causal_mask, index):
|
||||||
global relation_distance
|
global relation_distance
|
||||||
size = query.shape[2]
|
size = query.shape[2]
|
||||||
|
|
Loading…
Reference in New Issue