Refine query.
This commit is contained in:
parent
bd5379f24e
commit
e048d0c531
|
@ -0,0 +1,161 @@
|
||||||
|
len of seq:41
|
||||||
|
index: 1 golden: 6 -> 15 ERR
|
||||||
|
index: 12 golden: 10 -> 22 ERR
|
||||||
|
995
|
||||||
|
├── 202
|
||||||
|
│ ├── <0> 12
|
||||||
|
│ ├── 60
|
||||||
|
│ │ ├── <1> 6 ERR_15
|
||||||
|
│ │ ├── <2> 11
|
||||||
|
│ │ ├── <3> 8
|
||||||
|
│ │ ├── <4> 9
|
||||||
|
│ │ └── <5> 16
|
||||||
|
│ ├── 53
|
||||||
|
│ │ ├── <6> 3
|
||||||
|
│ │ ├── <7> 5
|
||||||
|
│ │ ├── <8> 11
|
||||||
|
│ │ ├── <9> 14
|
||||||
|
│ │ └── <10> 13
|
||||||
|
│ └── <11> 16
|
||||||
|
├── 51
|
||||||
|
│ ├── <12> 10 ERR_22
|
||||||
|
│ ├── <13> 6
|
||||||
|
│ └── <14> 16
|
||||||
|
├── 205
|
||||||
|
│ ├── 76
|
||||||
|
│ │ ├── <15> 0
|
||||||
|
│ │ ├── <16> 25
|
||||||
|
│ │ ├── <17> 3
|
||||||
|
│ │ ├── <18> 6
|
||||||
|
│ │ └── <19> 6
|
||||||
|
│ ├── <20> 19
|
||||||
|
│ ├── <21> 18
|
||||||
|
│ ├── <22> 2
|
||||||
|
│ ├── 55
|
||||||
|
│ │ ├── <23> 14
|
||||||
|
│ │ ├── 31
|
||||||
|
│ │ │ ├── <24> 8
|
||||||
|
│ │ │ ├── <25> 5
|
||||||
|
│ │ │ ├── <26> 5
|
||||||
|
│ │ │ └── <27> 2
|
||||||
|
│ │ ├── <28> 0
|
||||||
|
│ │ ├── <29> 4
|
||||||
|
│ │ ├── <30> 1
|
||||||
|
│ │ └── <31> 1
|
||||||
|
│ └── 31
|
||||||
|
│ ├── <32> 8
|
||||||
|
│ ├── <33> 5
|
||||||
|
│ ├── <34> 5
|
||||||
|
│ └── <35> 2
|
||||||
|
└── 217
|
||||||
|
├── <36> 3
|
||||||
|
└── 31
|
||||||
|
├── <37> 8
|
||||||
|
├── <38> 5
|
||||||
|
├── <39> 5
|
||||||
|
└── <40> 2
|
||||||
|
|
||||||
|
Mapping Load from disk cache: structured_language_50000_32_1_6_2_1_False_42
|
||||||
|
Mapping Load end, elapsed:0.27529025077819824s
|
||||||
|
len of seq:41
|
||||||
|
|
||||||
|
|
||||||
|
index: 2 golden: 6 -> 12 ERR
|
||||||
|
995
|
||||||
|
├── 202
|
||||||
|
│ ├── <0> 12
|
||||||
|
│ ├── <1> 12
|
||||||
|
│ ├── 60
|
||||||
|
│ │ ├── <2> 6 ERR_12
|
||||||
|
│ │ ├── <3> 6
|
||||||
|
│ │ ├── <4> 11
|
||||||
|
│ │ ├── <5> 11
|
||||||
|
│ │ ├── <6> 8
|
||||||
|
│ │ ├── <7> 8
|
||||||
|
│ │ ├── <8> 9
|
||||||
|
│ │ ├── <9> 9
|
||||||
|
│ │ ├── <10> 16
|
||||||
|
│ │ └── <11> 16
|
||||||
|
│ ├── 53
|
||||||
|
│ │ ├── <12> 3
|
||||||
|
│ │ ├── <13> 3
|
||||||
|
│ │ ├── <14> 5
|
||||||
|
│ │ ├── <15> 5
|
||||||
|
│ │ ├── <16> 11
|
||||||
|
│ │ ├── <17> 11
|
||||||
|
│ │ ├── <18> 14
|
||||||
|
│ │ ├── <19> 14
|
||||||
|
│ │ ├── <20> 13
|
||||||
|
│ │ └── <21> 13
|
||||||
|
│ ├── <22> 16
|
||||||
|
│ └── <23> 16
|
||||||
|
├── 51
|
||||||
|
│ ├── <24> 10
|
||||||
|
│ ├── <25> 10
|
||||||
|
│ ├── <26> 6
|
||||||
|
│ ├── <27> 6
|
||||||
|
│ ├── <28> 16
|
||||||
|
│ └── <29> 16
|
||||||
|
├── 205
|
||||||
|
│ ├── 76
|
||||||
|
│ │ ├── <30> 0
|
||||||
|
│ │ ├── <31> 0
|
||||||
|
│ │ ├── <32> 25
|
||||||
|
│ │ ├── <33> 25
|
||||||
|
│ │ ├── <34> 3
|
||||||
|
│ │ ├── <35> 3
|
||||||
|
│ │ ├── <36> 6
|
||||||
|
│ │ ├── <37> 6
|
||||||
|
│ │ ├── <38> 6
|
||||||
|
│ │ └── <39> 6
|
||||||
|
│ ├── <40> 19
|
||||||
|
│ ├── <41> 19
|
||||||
|
│ ├── <42> 18
|
||||||
|
│ ├── <43> 18
|
||||||
|
│ ├── <44> 2
|
||||||
|
│ ├── <45> 2
|
||||||
|
│ ├── 55
|
||||||
|
│ │ ├── <46> 14
|
||||||
|
│ │ ├── <47> 14
|
||||||
|
│ │ ├── 31
|
||||||
|
│ │ │ ├── <48> 8
|
||||||
|
│ │ │ ├── <49> 8
|
||||||
|
│ │ │ ├── <50> 5
|
||||||
|
│ │ │ ├── <51> 5
|
||||||
|
│ │ │ ├── <52> 5
|
||||||
|
│ │ │ ├── <53> 5
|
||||||
|
│ │ │ ├── <54> 2
|
||||||
|
│ │ │ └── <55> 2
|
||||||
|
│ │ ├── <56> 0
|
||||||
|
│ │ ├── <57> 0
|
||||||
|
│ │ ├── <58> 4
|
||||||
|
│ │ ├── <59> 4
|
||||||
|
│ │ ├── <60> 1
|
||||||
|
│ │ ├── <61> 1
|
||||||
|
│ │ ├── <62> 1
|
||||||
|
│ │ └── <63> 1
|
||||||
|
│ └── 31
|
||||||
|
│ ├── <64> 8
|
||||||
|
│ ├── <65> 8
|
||||||
|
│ ├── <66> 5
|
||||||
|
│ ├── <67> 5
|
||||||
|
│ ├── <68> 5
|
||||||
|
│ ├── <69> 5
|
||||||
|
│ ├── <70> 2
|
||||||
|
│ └── <71> 2
|
||||||
|
└── 217
|
||||||
|
├── <72> 3
|
||||||
|
├── <73> 3
|
||||||
|
└── 31
|
||||||
|
├── <74> 8
|
||||||
|
├── <75> 8
|
||||||
|
├── <76> 5
|
||||||
|
├── <77> 5
|
||||||
|
├── <78> 5
|
||||||
|
├── <79> 5
|
||||||
|
├── <80> 2
|
||||||
|
└── <81> 2
|
||||||
|
|
||||||
|
Mapping Load from disk cache: structured_language_50000_32_1_6_2_2_False_42
|
||||||
|
Mapping Load end, elapsed:0.2843344211578369s
|
||||||
|
len of seq:82
|
|
@ -223,6 +223,17 @@ class MeaningMap:
|
||||||
self.ms_rank_all[start : start + len],
|
self.ms_rank_all[start : start + len],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_maps(self, contain=-1):
|
||||||
|
seqlist = []
|
||||||
|
start = self.ms_start[contain]
|
||||||
|
seq = self.ms_data[start : start + self.ms_len[contain]]
|
||||||
|
level = self.ms_level[start : start + self.ms_len[contain]]
|
||||||
|
ms = self.ms_map
|
||||||
|
for m in range(contain, len(ms)):
|
||||||
|
if contain in ms[m]:
|
||||||
|
seqlist.append(ms[m])
|
||||||
|
return seqlist
|
||||||
|
|
||||||
def get_nodetree(self, meaning): # return meaning all sub items
|
def get_nodetree(self, meaning): # return meaning all sub items
|
||||||
def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
|
def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
|
||||||
ms = self.ms_map[meaning]
|
ms = self.ms_map[meaning]
|
||||||
|
|
|
@ -46,11 +46,23 @@ def get_dataset_set_freq(dataset):
|
||||||
print(f"meaning of {m} count as {total}")
|
print(f"meaning of {m} count as {total}")
|
||||||
|
|
||||||
|
|
||||||
def get_inference(dataset, seq):
|
def get_inference_single(dataset, meaning, index):
|
||||||
map = dataset.get_meaning_map()
|
map = dataset.get_meaning_map()
|
||||||
|
|
||||||
node = map.get_nodetree(seq)
|
item, l, rank_idx, rank_all = map.get_sequence(meaning)
|
||||||
item, l, rank_idx, rank_all = map.get_sequence(seq)
|
print("len of seq:" + str(len(item)))
|
||||||
|
|
||||||
|
batch = torch.tensor([item[:index]], dtype=torch.int64)
|
||||||
|
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||||
|
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference(dataset, meaning):
|
||||||
|
map = dataset.get_meaning_map()
|
||||||
|
|
||||||
|
node = map.get_nodetree(meaning)
|
||||||
|
item, l, rank_idx, rank_all = map.get_sequence(meaning)
|
||||||
print("len of seq:" + str(len(item)))
|
print("len of seq:" + str(len(item)))
|
||||||
|
|
||||||
for i in range(1, len(item)):
|
for i in range(1, len(item)):
|
||||||
|
@ -64,38 +76,14 @@ def get_inference(dataset, seq):
|
||||||
node.print()
|
node.print()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def get_contain_meaning(dataset, meaning):
|
||||||
|
map = dataset.get_meaning_map()
|
||||||
|
seqs = map.get_maps(contain=meaning)
|
||||||
|
for seq in seqs:
|
||||||
|
print(str(seq))
|
||||||
|
|
||||||
log_path = "log/bigger/version_1/"
|
|
||||||
|
|
||||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
def dump_dk(dataset, meaning):
|
||||||
checkpoint_path = log_path + "checkpoints/" + file
|
|
||||||
conf = configuration.class_from_file(log_path + "conf.pkl")
|
|
||||||
model = QWenLMHeadModel(conf.model_config)
|
|
||||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
|
|
||||||
qwen.eval()
|
|
||||||
conf = qwen.config
|
|
||||||
torch.manual_seed(conf.seed)
|
|
||||||
np.random.seed(conf.seed)
|
|
||||||
runner = ModelRunner(qwen.llm)
|
|
||||||
|
|
||||||
train, val = ds.InitDataset(conf)
|
|
||||||
val = val.dataset
|
|
||||||
# get_dataset_set_freq(train.dataset)
|
|
||||||
md = val.meaning_dataset
|
|
||||||
map = md.get_meaning_map()
|
|
||||||
|
|
||||||
# seq:844
|
|
||||||
# seq:849
|
|
||||||
# seq:991
|
|
||||||
# seq:995
|
|
||||||
meaning = 991
|
|
||||||
meaning = 10991
|
|
||||||
|
|
||||||
node = map.get_nodetree(meaning)
|
|
||||||
node.print()
|
|
||||||
|
|
||||||
get_inference(md, meaning)
|
|
||||||
|
|
||||||
def DumpQK(query, key, causal_mask, index):
|
def DumpQK(query, key, causal_mask, index):
|
||||||
global relation_distance
|
global relation_distance
|
||||||
|
@ -121,6 +109,8 @@ if __name__ == "__main__":
|
||||||
# print(current_to_common)
|
# print(current_to_common)
|
||||||
# print(common_to_current)
|
# print(common_to_current)
|
||||||
|
|
||||||
|
map = dataset.get_meaning_map()
|
||||||
|
|
||||||
relation_distance = map.get_relation_distance(meaning)
|
relation_distance = map.get_relation_distance(meaning)
|
||||||
relation_distance = torch.tensor(relation_distance)
|
relation_distance = torch.tensor(relation_distance)
|
||||||
rd = relation_distance.unsqueeze(0)
|
rd = relation_distance.unsqueeze(0)
|
||||||
|
@ -139,3 +129,44 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# print(sorted_logits.detach().cpu().numpy())
|
# print(sorted_logits.detach().cpu().numpy())
|
||||||
# print(sorted_indices.detach().cpu().numpy())
|
# print(sorted_indices.detach().cpu().numpy())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
log_path = "log/bigger/version_0/"
|
||||||
|
|
||||||
|
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||||
|
checkpoint_path = log_path + "checkpoints/" + file
|
||||||
|
conf = configuration.class_from_file(log_path + "conf.pkl")
|
||||||
|
model = QWenLMHeadModel(conf.model_config)
|
||||||
|
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
|
||||||
|
qwen.eval()
|
||||||
|
conf = qwen.config
|
||||||
|
torch.manual_seed(conf.seed)
|
||||||
|
np.random.seed(conf.seed)
|
||||||
|
runner = ModelRunner(qwen.llm)
|
||||||
|
|
||||||
|
train, val = ds.InitDataset(conf)
|
||||||
|
val = val.dataset
|
||||||
|
# get_dataset_set_freq(train.dataset)
|
||||||
|
md = val.meaning_dataset
|
||||||
|
map = md.get_meaning_map()
|
||||||
|
|
||||||
|
# seq:844
|
||||||
|
# seq:849
|
||||||
|
# seq:991
|
||||||
|
# seq:995
|
||||||
|
meaning = 995
|
||||||
|
meaning = 2995
|
||||||
|
|
||||||
|
# node = map.get_nodetree(meaning)
|
||||||
|
# node.print()
|
||||||
|
|
||||||
|
to = get_inference_single(md, meaning, 12)
|
||||||
|
print(to)
|
||||||
|
|
||||||
|
get_inference(md, meaning)
|
||||||
|
|
||||||
|
get_contain_meaning(md, meaning)
|
||||||
|
|
||||||
|
dump_dk(md, meaning)
|
||||||
|
|
Loading…
Reference in New Issue