Refine query.

This commit is contained in:
Colin 2025-08-26 13:31:03 +08:00
parent bd5379f24e
commit e048d0c531
3 changed files with 236 additions and 33 deletions

161
wit/995 Normal file
View File

@ -0,0 +1,161 @@
len of seq:41
index: 1 golden: 6 -> 15 ERR
index: 12 golden: 10 -> 22 ERR
995
├── 202
│ ├── <0> 12
│ ├── 60
│ │ ├── <1> 6 ERR_15
│ │ ├── <2> 11
│ │ ├── <3> 8
│ │ ├── <4> 9
│ │ └── <5> 16
│ ├── 53
│ │ ├── <6> 3
│ │ ├── <7> 5
│ │ ├── <8> 11
│ │ ├── <9> 14
│ │ └── <10> 13
│ └── <11> 16
├── 51
│ ├── <12> 10 ERR_22
│ ├── <13> 6
│ └── <14> 16
├── 205
│ ├── 76
│ │ ├── <15> 0
│ │ ├── <16> 25
│ │ ├── <17> 3
│ │ ├── <18> 6
│ │ └── <19> 6
│ ├── <20> 19
│ ├── <21> 18
│ ├── <22> 2
│ ├── 55
│ │ ├── <23> 14
│ │ ├── 31
│ │ │ ├── <24> 8
│ │ │ ├── <25> 5
│ │ │ ├── <26> 5
│ │ │ └── <27> 2
│ │ ├── <28> 0
│ │ ├── <29> 4
│ │ ├── <30> 1
│ │ └── <31> 1
│ └── 31
│ ├── <32> 8
│ ├── <33> 5
│ ├── <34> 5
│ └── <35> 2
└── 217
├── <36> 3
└── 31
├── <37> 8
├── <38> 5
├── <39> 5
└── <40> 2
Mapping Load from disk cache: structured_language_50000_32_1_6_2_1_False_42
Mapping Load end, elapsed:0.27529025077819824s
len of seq:41
index: 2 golden: 6 -> 12 ERR
995
├── 202
│ ├── <0> 12
│ ├── <1> 12
│ ├── 60
│ │ ├── <2> 6 ERR_12
│ │ ├── <3> 6
│ │ ├── <4> 11
│ │ ├── <5> 11
│ │ ├── <6> 8
│ │ ├── <7> 8
│ │ ├── <8> 9
│ │ ├── <9> 9
│ │ ├── <10> 16
│ │ └── <11> 16
│ ├── 53
│ │ ├── <12> 3
│ │ ├── <13> 3
│ │ ├── <14> 5
│ │ ├── <15> 5
│ │ ├── <16> 11
│ │ ├── <17> 11
│ │ ├── <18> 14
│ │ ├── <19> 14
│ │ ├── <20> 13
│ │ └── <21> 13
│ ├── <22> 16
│ └── <23> 16
├── 51
│ ├── <24> 10
│ ├── <25> 10
│ ├── <26> 6
│ ├── <27> 6
│ ├── <28> 16
│ └── <29> 16
├── 205
│ ├── 76
│ │ ├── <30> 0
│ │ ├── <31> 0
│ │ ├── <32> 25
│ │ ├── <33> 25
│ │ ├── <34> 3
│ │ ├── <35> 3
│ │ ├── <36> 6
│ │ ├── <37> 6
│ │ ├── <38> 6
│ │ └── <39> 6
│ ├── <40> 19
│ ├── <41> 19
│ ├── <42> 18
│ ├── <43> 18
│ ├── <44> 2
│ ├── <45> 2
│ ├── 55
│ │ ├── <46> 14
│ │ ├── <47> 14
│ │ ├── 31
│ │ │ ├── <48> 8
│ │ │ ├── <49> 8
│ │ │ ├── <50> 5
│ │ │ ├── <51> 5
│ │ │ ├── <52> 5
│ │ │ ├── <53> 5
│ │ │ ├── <54> 2
│ │ │ └── <55> 2
│ │ ├── <56> 0
│ │ ├── <57> 0
│ │ ├── <58> 4
│ │ ├── <59> 4
│ │ ├── <60> 1
│ │ ├── <61> 1
│ │ ├── <62> 1
│ │ └── <63> 1
│ └── 31
│ ├── <64> 8
│ ├── <65> 8
│ ├── <66> 5
│ ├── <67> 5
│ ├── <68> 5
│ ├── <69> 5
│ ├── <70> 2
│ └── <71> 2
└── 217
├── <72> 3
├── <73> 3
└── 31
├── <74> 8
├── <75> 8
├── <76> 5
├── <77> 5
├── <78> 5
├── <79> 5
├── <80> 2
└── <81> 2
Mapping Load from disk cache: structured_language_50000_32_1_6_2_2_False_42
Mapping Load end, elapsed:0.2843344211578369s
len of seq:82

View File

@ -223,6 +223,17 @@ class MeaningMap:
self.ms_rank_all[start : start + len],
)
def get_maps(self, contain=-1):
    """Return every meaning-map entry that contains the meaning *contain*.

    Scans ``self.ms_map`` starting at index *contain* and collects each
    map that lists *contain* as a member.

    NOTE(review): with the default ``contain=-1`` the scan starts at
    index -1, i.e. it first tests the *last* map before wrapping to the
    full range — presumably intentional only because real meaning ids
    are non-negative; confirm with callers.

    Fix vs. original: the original also sliced ``self.ms_data`` /
    ``self.ms_level`` (via ``ms_start``/``ms_len``) into locals that were
    never used; that dead code is removed.
    """
    seqlist = []
    ms = self.ms_map
    # Assumes a meaning can only be contained by maps with an equal or
    # higher index than itself — TODO confirm against map construction.
    for m in range(contain, len(ms)):
        if contain in ms[m]:
            seqlist.append(ms[m])
    return seqlist
def get_nodetree(self, meaning): # return meaning all sub items
def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
ms = self.ms_map[meaning]

View File

@ -46,11 +46,23 @@ def get_dataset_set_freq(dataset):
print(f"meaning of {m} count as {total}")
def get_inference(dataset, seq):
def get_inference_single(dataset, meaning, index):
map = dataset.get_meaning_map()
node = map.get_nodetree(seq)
item, l, rank_idx, rank_all = map.get_sequence(seq)
item, l, rank_idx, rank_all = map.get_sequence(meaning)
print("len of seq:" + str(len(item)))
batch = torch.tensor([item[:index]], dtype=torch.int64)
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
next_token = sorted_indices.detach().cpu().numpy()[0][0]
return next_token
def get_inference(dataset, meaning):
map = dataset.get_meaning_map()
node = map.get_nodetree(meaning)
item, l, rank_idx, rank_all = map.get_sequence(meaning)
print("len of seq:" + str(len(item)))
for i in range(1, len(item)):
@ -64,38 +76,14 @@ def get_inference(dataset, seq):
node.print()
if __name__ == "__main__":
def get_contain_meaning(dataset, meaning):
    """Print, one per line, each meaning-map entry that contains *meaning*."""
    meaning_map = dataset.get_meaning_map()
    for entry in meaning_map.get_maps(contain=meaning):
        print(str(entry))
log_path = "log/bigger/version_1/"
file = get_latest_file_safe(log_path + "/checkpoints")
checkpoint_path = log_path + "checkpoints/" + file
conf = configuration.class_from_file(log_path + "conf.pkl")
model = QWenLMHeadModel(conf.model_config)
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
qwen.eval()
conf = qwen.config
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
runner = ModelRunner(qwen.llm)
train, val = ds.InitDataset(conf)
val = val.dataset
# get_dataset_set_freq(train.dataset)
md = val.meaning_dataset
map = md.get_meaning_map()
# seq:844
# seq:849
# seq:991
# seq:995
meaning = 991
meaning = 10991
node = map.get_nodetree(meaning)
node.print()
get_inference(md, meaning)
def dump_dk(dataset, meaning):
def DumpQK(query, key, causal_mask, index):
global relation_distance
@ -121,6 +109,8 @@ if __name__ == "__main__":
# print(current_to_common)
# print(common_to_current)
map = dataset.get_meaning_map()
relation_distance = map.get_relation_distance(meaning)
relation_distance = torch.tensor(relation_distance)
rd = relation_distance.unsqueeze(0)
@ -139,3 +129,44 @@ if __name__ == "__main__":
# print(sorted_logits.detach().cpu().numpy())
# print(sorted_indices.detach().cpu().numpy())
# Entry point: rebuild the model from the newest checkpoint under
# log/bigger/version_0/, seed RNGs from the saved config, then run the
# meaning-tree inspection helpers on a single meaning id.
# NOTE(review): indentation of this block was stripped by the diff
# renderer; the statements below all belong to the __main__ body.
if __name__ == "__main__":
log_path = "log/bigger/version_0/"
# Newest checkpoint file in the run directory — presumably by mtime;
# verify get_latest_file_safe's ordering.
file = get_latest_file_safe(log_path + "/checkpoints")
checkpoint_path = log_path + "checkpoints/" + file
# Restore the training-time configuration and model weights.
conf = configuration.class_from_file(log_path + "conf.pkl")
model = QWenLMHeadModel(conf.model_config)
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
qwen.eval()
conf = qwen.config
# Deterministic sampling: seed both torch and numpy from the run config.
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
runner = ModelRunner(qwen.llm)
train, val = ds.InitDataset(conf)
val = val.dataset
# get_dataset_set_freq(train.dataset)
md = val.meaning_dataset
map = md.get_meaning_map()
# Candidate meaning ids tried during exploration:
# seq:844
# seq:849
# seq:991
# seq:995
# NOTE(review): the first assignment is immediately overwritten — the
# effective meaning id is 2995; delete one of the two lines.
meaning = 995
meaning = 2995
# node = map.get_nodetree(meaning)
# node.print()
# Predict the token at position 12 of the meaning's sequence, then run
# the full per-position inference sweep and the containment/QK dumps.
to = get_inference_single(md, meaning, 12)
print(to)
get_inference(md, meaning)
get_contain_meaning(md, meaning)
dump_dk(md, meaning)