Refine query.
This commit is contained in:
		
							parent
							
								
									bd5379f24e
								
							
						
					
					
						commit
						e048d0c531
					
				| 
						 | 
				
			
			@ -0,0 +1,161 @@
 | 
			
		|||
len of seq:41
 | 
			
		||||
index: 1 golden: 6 -> 15  ERR
 | 
			
		||||
index: 12 golden: 10 -> 22  ERR
 | 
			
		||||
995 
 | 
			
		||||
├── 202 
 | 
			
		||||
│   ├── <0> 12 
 | 
			
		||||
│   ├── 60 
 | 
			
		||||
│   │   ├── <1> 6 ERR_15
 | 
			
		||||
│   │   ├── <2> 11 
 | 
			
		||||
│   │   ├── <3> 8 
 | 
			
		||||
│   │   ├── <4> 9 
 | 
			
		||||
│   │   └── <5> 16 
 | 
			
		||||
│   ├── 53 
 | 
			
		||||
│   │   ├── <6> 3 
 | 
			
		||||
│   │   ├── <7> 5 
 | 
			
		||||
│   │   ├── <8> 11 
 | 
			
		||||
│   │   ├── <9> 14 
 | 
			
		||||
│   │   └── <10> 13 
 | 
			
		||||
│   └── <11> 16 
 | 
			
		||||
├── 51 
 | 
			
		||||
│   ├── <12> 10 ERR_22
 | 
			
		||||
│   ├── <13> 6 
 | 
			
		||||
│   └── <14> 16 
 | 
			
		||||
├── 205 
 | 
			
		||||
│   ├── 76 
 | 
			
		||||
│   │   ├── <15> 0 
 | 
			
		||||
│   │   ├── <16> 25 
 | 
			
		||||
│   │   ├── <17> 3 
 | 
			
		||||
│   │   ├── <18> 6 
 | 
			
		||||
│   │   └── <19> 6 
 | 
			
		||||
│   ├── <20> 19 
 | 
			
		||||
│   ├── <21> 18 
 | 
			
		||||
│   ├── <22> 2 
 | 
			
		||||
│   ├── 55 
 | 
			
		||||
│   │   ├── <23> 14 
 | 
			
		||||
│   │   ├── 31 
 | 
			
		||||
│   │   │   ├── <24> 8 
 | 
			
		||||
│   │   │   ├── <25> 5 
 | 
			
		||||
│   │   │   ├── <26> 5 
 | 
			
		||||
│   │   │   └── <27> 2 
 | 
			
		||||
│   │   ├── <28> 0 
 | 
			
		||||
│   │   ├── <29> 4 
 | 
			
		||||
│   │   ├── <30> 1 
 | 
			
		||||
│   │   └── <31> 1 
 | 
			
		||||
│   └── 31 
 | 
			
		||||
│       ├── <32> 8 
 | 
			
		||||
│       ├── <33> 5 
 | 
			
		||||
│       ├── <34> 5 
 | 
			
		||||
│       └── <35> 2 
 | 
			
		||||
└── 217 
 | 
			
		||||
    ├── <36> 3 
 | 
			
		||||
    └── 31 
 | 
			
		||||
        ├── <37> 8 
 | 
			
		||||
        ├── <38> 5 
 | 
			
		||||
        ├── <39> 5 
 | 
			
		||||
        └── <40> 2 
 | 
			
		||||
 | 
			
		||||
Mapping Load from disk cache: structured_language_50000_32_1_6_2_1_False_42
 | 
			
		||||
Mapping Load end, elapsed:0.27529025077819824s
 | 
			
		||||
len of seq:41
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
index: 2 golden: 6 -> 12  ERR
 | 
			
		||||
995 
 | 
			
		||||
├── 202 
 | 
			
		||||
│   ├── <0> 12 
 | 
			
		||||
│   ├── <1> 12 
 | 
			
		||||
│   ├── 60 
 | 
			
		||||
│   │   ├── <2> 6 ERR_12
 | 
			
		||||
│   │   ├── <3> 6 
 | 
			
		||||
│   │   ├── <4> 11 
 | 
			
		||||
│   │   ├── <5> 11 
 | 
			
		||||
│   │   ├── <6> 8 
 | 
			
		||||
│   │   ├── <7> 8 
 | 
			
		||||
│   │   ├── <8> 9 
 | 
			
		||||
│   │   ├── <9> 9 
 | 
			
		||||
│   │   ├── <10> 16 
 | 
			
		||||
│   │   └── <11> 16 
 | 
			
		||||
│   ├── 53 
 | 
			
		||||
│   │   ├── <12> 3 
 | 
			
		||||
│   │   ├── <13> 3 
 | 
			
		||||
│   │   ├── <14> 5 
 | 
			
		||||
│   │   ├── <15> 5 
 | 
			
		||||
│   │   ├── <16> 11 
 | 
			
		||||
│   │   ├── <17> 11 
 | 
			
		||||
│   │   ├── <18> 14 
 | 
			
		||||
│   │   ├── <19> 14 
 | 
			
		||||
│   │   ├── <20> 13 
 | 
			
		||||
│   │   └── <21> 13 
 | 
			
		||||
│   ├── <22> 16 
 | 
			
		||||
│   └── <23> 16 
 | 
			
		||||
├── 51 
 | 
			
		||||
│   ├── <24> 10 
 | 
			
		||||
│   ├── <25> 10 
 | 
			
		||||
│   ├── <26> 6 
 | 
			
		||||
│   ├── <27> 6 
 | 
			
		||||
│   ├── <28> 16 
 | 
			
		||||
│   └── <29> 16 
 | 
			
		||||
├── 205 
 | 
			
		||||
│   ├── 76 
 | 
			
		||||
│   │   ├── <30> 0 
 | 
			
		||||
│   │   ├── <31> 0 
 | 
			
		||||
│   │   ├── <32> 25 
 | 
			
		||||
│   │   ├── <33> 25 
 | 
			
		||||
│   │   ├── <34> 3 
 | 
			
		||||
│   │   ├── <35> 3 
 | 
			
		||||
│   │   ├── <36> 6 
 | 
			
		||||
│   │   ├── <37> 6 
 | 
			
		||||
│   │   ├── <38> 6 
 | 
			
		||||
│   │   └── <39> 6 
 | 
			
		||||
│   ├── <40> 19 
 | 
			
		||||
│   ├── <41> 19 
 | 
			
		||||
│   ├── <42> 18 
 | 
			
		||||
│   ├── <43> 18 
 | 
			
		||||
│   ├── <44> 2 
 | 
			
		||||
│   ├── <45> 2 
 | 
			
		||||
│   ├── 55 
 | 
			
		||||
│   │   ├── <46> 14 
 | 
			
		||||
│   │   ├── <47> 14 
 | 
			
		||||
│   │   ├── 31 
 | 
			
		||||
│   │   │   ├── <48> 8 
 | 
			
		||||
│   │   │   ├── <49> 8 
 | 
			
		||||
│   │   │   ├── <50> 5 
 | 
			
		||||
│   │   │   ├── <51> 5 
 | 
			
		||||
│   │   │   ├── <52> 5 
 | 
			
		||||
│   │   │   ├── <53> 5 
 | 
			
		||||
│   │   │   ├── <54> 2 
 | 
			
		||||
│   │   │   └── <55> 2 
 | 
			
		||||
│   │   ├── <56> 0 
 | 
			
		||||
│   │   ├── <57> 0 
 | 
			
		||||
│   │   ├── <58> 4 
 | 
			
		||||
│   │   ├── <59> 4 
 | 
			
		||||
│   │   ├── <60> 1 
 | 
			
		||||
│   │   ├── <61> 1 
 | 
			
		||||
│   │   ├── <62> 1 
 | 
			
		||||
│   │   └── <63> 1 
 | 
			
		||||
│   └── 31 
 | 
			
		||||
│       ├── <64> 8 
 | 
			
		||||
│       ├── <65> 8 
 | 
			
		||||
│       ├── <66> 5 
 | 
			
		||||
│       ├── <67> 5 
 | 
			
		||||
│       ├── <68> 5 
 | 
			
		||||
│       ├── <69> 5 
 | 
			
		||||
│       ├── <70> 2 
 | 
			
		||||
│       └── <71> 2 
 | 
			
		||||
└── 217 
 | 
			
		||||
    ├── <72> 3 
 | 
			
		||||
    ├── <73> 3 
 | 
			
		||||
    └── 31 
 | 
			
		||||
        ├── <74> 8 
 | 
			
		||||
        ├── <75> 8 
 | 
			
		||||
        ├── <76> 5 
 | 
			
		||||
        ├── <77> 5 
 | 
			
		||||
        ├── <78> 5 
 | 
			
		||||
        ├── <79> 5 
 | 
			
		||||
        ├── <80> 2 
 | 
			
		||||
        └── <81> 2 
 | 
			
		||||
 | 
			
		||||
Mapping Load from disk cache: structured_language_50000_32_1_6_2_2_False_42
 | 
			
		||||
Mapping Load end, elapsed:0.2843344211578369s
 | 
			
		||||
len of seq:82
 | 
			
		||||
| 
						 | 
				
			
			@ -223,6 +223,17 @@ class MeaningMap:
 | 
			
		|||
            self.ms_rank_all[start : start + len],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_maps(self, contain=-1):
 | 
			
		||||
        seqlist = []
 | 
			
		||||
        start = self.ms_start[contain]
 | 
			
		||||
        seq = self.ms_data[start : start + self.ms_len[contain]]
 | 
			
		||||
        level = self.ms_level[start : start + self.ms_len[contain]]
 | 
			
		||||
        ms = self.ms_map
 | 
			
		||||
        for m in range(contain, len(ms)):
 | 
			
		||||
            if contain in ms[m]:
 | 
			
		||||
                seqlist.append(ms[m])
 | 
			
		||||
        return seqlist
 | 
			
		||||
 | 
			
		||||
    def get_nodetree(self, meaning):  # return meaning all sub items
 | 
			
		||||
        def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
 | 
			
		||||
            ms = self.ms_map[meaning]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -46,11 +46,23 @@ def get_dataset_set_freq(dataset):
 | 
			
		|||
        print(f"meaning of {m} count as {total}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_inference(dataset, seq):
 | 
			
		||||
def get_inference_single(dataset, meaning, index):
 | 
			
		||||
    map = dataset.get_meaning_map()
 | 
			
		||||
 | 
			
		||||
    node = map.get_nodetree(seq)
 | 
			
		||||
    item, l, rank_idx, rank_all = map.get_sequence(seq)
 | 
			
		||||
    item, l, rank_idx, rank_all = map.get_sequence(meaning)
 | 
			
		||||
    print("len of seq:" + str(len(item)))
 | 
			
		||||
 | 
			
		||||
    batch = torch.tensor([item[:index]], dtype=torch.int64)
 | 
			
		||||
    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
 | 
			
		||||
    next_token = sorted_indices.detach().cpu().numpy()[0][0]
 | 
			
		||||
    return next_token
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_inference(dataset, meaning):
 | 
			
		||||
    map = dataset.get_meaning_map()
 | 
			
		||||
 | 
			
		||||
    node = map.get_nodetree(meaning)
 | 
			
		||||
    item, l, rank_idx, rank_all = map.get_sequence(meaning)
 | 
			
		||||
    print("len of seq:" + str(len(item)))
 | 
			
		||||
 | 
			
		||||
    for i in range(1, len(item)):
 | 
			
		||||
| 
						 | 
				
			
			@ -64,38 +76,14 @@ def get_inference(dataset, seq):
 | 
			
		|||
    node.print()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
def get_contain_meaning(dataset, meaning):
 | 
			
		||||
    map = dataset.get_meaning_map()
 | 
			
		||||
    seqs = map.get_maps(contain=meaning)
 | 
			
		||||
    for seq in seqs:
 | 
			
		||||
        print(str(seq))
 | 
			
		||||
 | 
			
		||||
    log_path = "log/bigger/version_1/"
 | 
			
		||||
 | 
			
		||||
    file = get_latest_file_safe(log_path + "/checkpoints")
 | 
			
		||||
    checkpoint_path = log_path + "checkpoints/" + file
 | 
			
		||||
    conf = configuration.class_from_file(log_path + "conf.pkl")
 | 
			
		||||
    model = QWenLMHeadModel(conf.model_config)
 | 
			
		||||
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
 | 
			
		||||
    qwen.eval()
 | 
			
		||||
    conf = qwen.config
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
    np.random.seed(conf.seed)
 | 
			
		||||
    runner = ModelRunner(qwen.llm)
 | 
			
		||||
 | 
			
		||||
    train, val = ds.InitDataset(conf)
 | 
			
		||||
    val = val.dataset
 | 
			
		||||
    # get_dataset_set_freq(train.dataset)
 | 
			
		||||
    md = val.meaning_dataset
 | 
			
		||||
    map = md.get_meaning_map()
 | 
			
		||||
 | 
			
		||||
    # seq:844
 | 
			
		||||
    # seq:849
 | 
			
		||||
    # seq:991
 | 
			
		||||
    # seq:995
 | 
			
		||||
    meaning = 991
 | 
			
		||||
    meaning = 10991
 | 
			
		||||
 | 
			
		||||
    node = map.get_nodetree(meaning)
 | 
			
		||||
    node.print()
 | 
			
		||||
 | 
			
		||||
    get_inference(md, meaning)
 | 
			
		||||
def dump_dk(dataset, meaning):
 | 
			
		||||
 | 
			
		||||
    def DumpQK(query, key, causal_mask, index):
 | 
			
		||||
        global relation_distance
 | 
			
		||||
| 
						 | 
				
			
			@ -121,6 +109,8 @@ if __name__ == "__main__":
 | 
			
		|||
    # print(current_to_common)
 | 
			
		||||
    # print(common_to_current)
 | 
			
		||||
 | 
			
		||||
    map = dataset.get_meaning_map()
 | 
			
		||||
 | 
			
		||||
    relation_distance = map.get_relation_distance(meaning)
 | 
			
		||||
    relation_distance = torch.tensor(relation_distance)
 | 
			
		||||
    rd = relation_distance.unsqueeze(0)
 | 
			
		||||
| 
						 | 
				
			
			@ -139,3 +129,44 @@ if __name__ == "__main__":
 | 
			
		|||
 | 
			
		||||
    # print(sorted_logits.detach().cpu().numpy())
 | 
			
		||||
    # print(sorted_indices.detach().cpu().numpy())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
    log_path = "log/bigger/version_0/"
 | 
			
		||||
 | 
			
		||||
    file = get_latest_file_safe(log_path + "/checkpoints")
 | 
			
		||||
    checkpoint_path = log_path + "checkpoints/" + file
 | 
			
		||||
    conf = configuration.class_from_file(log_path + "conf.pkl")
 | 
			
		||||
    model = QWenLMHeadModel(conf.model_config)
 | 
			
		||||
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
 | 
			
		||||
    qwen.eval()
 | 
			
		||||
    conf = qwen.config
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
    np.random.seed(conf.seed)
 | 
			
		||||
    runner = ModelRunner(qwen.llm)
 | 
			
		||||
 | 
			
		||||
    train, val = ds.InitDataset(conf)
 | 
			
		||||
    val = val.dataset
 | 
			
		||||
    # get_dataset_set_freq(train.dataset)
 | 
			
		||||
    md = val.meaning_dataset
 | 
			
		||||
    map = md.get_meaning_map()
 | 
			
		||||
 | 
			
		||||
    # seq:844
 | 
			
		||||
    # seq:849
 | 
			
		||||
    # seq:991
 | 
			
		||||
    # seq:995
 | 
			
		||||
    meaning = 995
 | 
			
		||||
    meaning = 2995
 | 
			
		||||
 | 
			
		||||
    # node = map.get_nodetree(meaning)
 | 
			
		||||
    # node.print()
 | 
			
		||||
 | 
			
		||||
    to = get_inference_single(md, meaning, 12)
 | 
			
		||||
    print(to)
 | 
			
		||||
 | 
			
		||||
    get_inference(md, meaning)
 | 
			
		||||
 | 
			
		||||
    get_contain_meaning(md, meaning)
 | 
			
		||||
 | 
			
		||||
    dump_dk(md, meaning)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue