From e048d0c531d67852ac8698e40a8cbfac2797de01 Mon Sep 17 00:00:00 2001 From: Colin <> Date: Tue, 26 Aug 2025 13:31:03 +0800 Subject: [PATCH] Refine query. --- wit/995 | 161 +++++++++++++++++++++++++++++++++ wit/meaning/meaning_dataset.py | 11 +++ wit/query_block_output.py | 97 +++++++++++++------- 3 files changed, 236 insertions(+), 33 deletions(-) create mode 100644 wit/995 diff --git a/wit/995 b/wit/995 new file mode 100644 index 0000000..0a3a62d --- /dev/null +++ b/wit/995 @@ -0,0 +1,161 @@ +len of seq:41 +index: 1 golden: 6 -> 15 ERR +index: 12 golden: 10 -> 22 ERR +995 +├── 202 +│ ├── <0> 12 +│ ├── 60 +│ │ ├── <1> 6 ERR_15 +│ │ ├── <2> 11 +│ │ ├── <3> 8 +│ │ ├── <4> 9 +│ │ └── <5> 16 +│ ├── 53 +│ │ ├── <6> 3 +│ │ ├── <7> 5 +│ │ ├── <8> 11 +│ │ ├── <9> 14 +│ │ └── <10> 13 +│ └── <11> 16 +├── 51 +│ ├── <12> 10 ERR_22 +│ ├── <13> 6 +│ └── <14> 16 +├── 205 +│ ├── 76 +│ │ ├── <15> 0 +│ │ ├── <16> 25 +│ │ ├── <17> 3 +│ │ ├── <18> 6 +│ │ └── <19> 6 +│ ├── <20> 19 +│ ├── <21> 18 +│ ├── <22> 2 +│ ├── 55 +│ │ ├── <23> 14 +│ │ ├── 31 +│ │ │ ├── <24> 8 +│ │ │ ├── <25> 5 +│ │ │ ├── <26> 5 +│ │ │ └── <27> 2 +│ │ ├── <28> 0 +│ │ ├── <29> 4 +│ │ ├── <30> 1 +│ │ └── <31> 1 +│ └── 31 +│ ├── <32> 8 +│ ├── <33> 5 +│ ├── <34> 5 +│ └── <35> 2 +└── 217 + ├── <36> 3 + └── 31 + ├── <37> 8 + ├── <38> 5 + ├── <39> 5 + └── <40> 2 + +Mapping Load from disk cache: structured_language_50000_32_1_6_2_1_False_42 +Mapping Load end, elapsed:0.27529025077819824s +len of seq:41 + + +index: 2 golden: 6 -> 12 ERR +995 +├── 202 +│ ├── <0> 12 +│ ├── <1> 12 +│ ├── 60 +│ │ ├── <2> 6 ERR_12 +│ │ ├── <3> 6 +│ │ ├── <4> 11 +│ │ ├── <5> 11 +│ │ ├── <6> 8 +│ │ ├── <7> 8 +│ │ ├── <8> 9 +│ │ ├── <9> 9 +│ │ ├── <10> 16 +│ │ └── <11> 16 +│ ├── 53 +│ │ ├── <12> 3 +│ │ ├── <13> 3 +│ │ ├── <14> 5 +│ │ ├── <15> 5 +│ │ ├── <16> 11 +│ │ ├── <17> 11 +│ │ ├── <18> 14 +│ │ ├── <19> 14 +│ │ ├── <20> 13 +│ │ └── <21> 13 +│ ├── <22> 16 +│ └── <23> 16 +├── 51 +│ ├── <24> 10 +│ ├── <25> 10 +│ ├── <26> 6 +│ ├── <27> 6 +│ ├── <28> 16 +│ └── <29> 16 +├── 205 +│ ├── 76 +│ │ ├── <30> 0 +│ │ ├── <31> 0 +│ │ ├── <32> 25 +│ │ ├── <33> 25 +│ │ ├── <34> 3 +│ │ ├── <35> 3 +│ │ ├── <36> 6 +│ │ ├── <37> 6 +│ │ ├── <38> 6 +│ │ └── <39> 6 +│ ├── <40> 19 +│ ├── <41> 19 +│ ├── <42> 18 +│ ├── <43> 18 +│ ├── <44> 2 +│ ├── <45> 2 +│ ├── 55 +│ │ ├── <46> 14 +│ │ ├── <47> 14 +│ │ ├── 31 +│ │ │ ├── <48> 8 +│ │ │ ├── <49> 8 +│ │ │ ├── <50> 5 +│ │ │ ├── <51> 5 +│ │ │ ├── <52> 5 +│ │ │ ├── <53> 5 +│ │ │ ├── <54> 2 +│ │ │ └── <55> 2 +│ │ ├── <56> 0 +│ │ ├── <57> 0 +│ │ ├── <58> 4 +│ │ ├── <59> 4 +│ │ ├── <60> 1 +│ │ ├── <61> 1 +│ │ ├── <62> 1 +│ │ └── <63> 1 +│ └── 31 +│ ├── <64> 8 +│ ├── <65> 8 +│ ├── <66> 5 +│ ├── <67> 5 +│ ├── <68> 5 +│ ├── <69> 5 +│ ├── <70> 2 +│ └── <71> 2 +└── 217 + ├── <72> 3 + ├── <73> 3 + └── 31 + ├── <74> 8 + ├── <75> 8 + ├── <76> 5 + ├── <77> 5 + ├── <78> 5 + ├── <79> 5 + ├── <80> 2 + └── <81> 2 + +Mapping Load from disk cache: structured_language_50000_32_1_6_2_2_False_42 +Mapping Load end, elapsed:0.2843344211578369s +len of seq:82 \ No newline at end of file diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py index 1d3fcc1..a6960a2 100644 --- a/wit/meaning/meaning_dataset.py +++ b/wit/meaning/meaning_dataset.py @@ -223,6 +223,17 @@ class MeaningMap: self.ms_rank_all[start : start + len], ) + def get_maps(self, contain=-1): + seqlist = [] + start = self.ms_start[contain] + seq = self.ms_data[start : start + self.ms_len[contain]] + level = self.ms_level[start : start + self.ms_len[contain]] + ms = self.ms_map + for m in range(contain, len(ms)): + if contain in ms[m]: + seqlist.append(ms[m]) + return seqlist + def get_nodetree(self, meaning): # return meaning all sub items def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index): ms = self.ms_map[meaning] diff --git a/wit/query_block_output.py b/wit/query_block_output.py index 6f964d7..ab8886f 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -46,11 +46,23 @@ def get_dataset_set_freq(dataset): print(f"meaning of {m} count as {total}") -def get_inference(dataset, seq): +def get_inference_single(dataset, meaning, index): map = dataset.get_meaning_map() - node = map.get_nodetree(seq) - item, l, rank_idx, rank_all = map.get_sequence(seq) + item, l, rank_idx, rank_all = map.get_sequence(meaning) + print("len of seq:" + str(len(item))) + + batch = torch.tensor([item[:index]], dtype=torch.int64) + sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) + next_token = sorted_indices.detach().cpu().numpy()[0][0] + return next_token + + +def get_inference(dataset, meaning): + map = dataset.get_meaning_map() + + node = map.get_nodetree(meaning) + item, l, rank_idx, rank_all = map.get_sequence(meaning) print("len of seq:" + str(len(item))) for i in range(1, len(item)): @@ -64,38 +76,14 @@ def get_inference(dataset, seq): node.print() -if __name__ == "__main__": +def get_contain_meaning(dataset, meaning): + map = dataset.get_meaning_map() + seqs = map.get_maps(contain=meaning) + for seq in seqs: + print(str(seq)) - log_path = "log/bigger/version_1/" - file = get_latest_file_safe(log_path + "/checkpoints") - checkpoint_path = log_path + "checkpoints/" + file - conf = configuration.class_from_file(log_path + "conf.pkl") - model = QWenLMHeadModel(conf.model_config) - qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model) - qwen.eval() - conf = qwen.config - torch.manual_seed(conf.seed) - np.random.seed(conf.seed) - runner = ModelRunner(qwen.llm) - - train, val = ds.InitDataset(conf) - val = val.dataset - # get_dataset_set_freq(train.dataset) - md = val.meaning_dataset - map = md.get_meaning_map() - - # seq:844 - # seq:849 - # seq:991 - # seq:995 - meaning = 991 - meaning = 10991 - - node = map.get_nodetree(meaning) - node.print() - - get_inference(md, meaning) +def dump_dk(dataset, meaning): def DumpQK(query, key, causal_mask, index): global relation_distance @@ -121,6 +109,8 @@ if __name__ == "__main__": # print(current_to_common) # print(common_to_current) + map = dataset.get_meaning_map() + relation_distance = map.get_relation_distance(meaning) relation_distance = torch.tensor(relation_distance) rd = relation_distance.unsqueeze(0) @@ -139,3 +129,44 @@ if __name__ == "__main__": # print(sorted_logits.detach().cpu().numpy()) # print(sorted_indices.detach().cpu().numpy()) + + +if __name__ == "__main__": + + log_path = "log/bigger/version_0/" + + file = get_latest_file_safe(log_path + "/checkpoints") + checkpoint_path = log_path + "checkpoints/" + file + conf = configuration.class_from_file(log_path + "conf.pkl") + model = QWenLMHeadModel(conf.model_config) + qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model) + qwen.eval() + conf = qwen.config + torch.manual_seed(conf.seed) + np.random.seed(conf.seed) + runner = ModelRunner(qwen.llm) + + train, val = ds.InitDataset(conf) + val = val.dataset + # get_dataset_set_freq(train.dataset) + md = val.meaning_dataset + map = md.get_meaning_map() + + # seq:844 + # seq:849 + # seq:991 + # seq:995 + meaning = 995 + meaning = 2995 + + # node = map.get_nodetree(meaning) + # node.print() + + to = get_inference_single(md, meaning, 12) + print(to) + + get_inference(md, meaning) + + get_contain_meaning(md, meaning) + + dump_dk(md, meaning)