Refine query_block_output.
This commit is contained in:
		
							parent
							
								
									ee30eb4aab
								
							
						
					
					
						commit
						3e6ff2d580
					
				|  | @ -1,3 +1,6 @@ | ||||||
|  | import pickle | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class ModelConfig: | class ModelConfig: | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.vocab_size = 4096 |         self.vocab_size = 4096 | ||||||
|  | @ -90,6 +93,17 @@ def class_to_dict(obj): | ||||||
|         return str(obj) |         return str(obj) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def class_to_file(obj, file): | ||||||
|  |     with open(file, "wb") as file: | ||||||
|  |         pickle.dump(obj, file) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def class_from_file(file): | ||||||
|  |     with open(file, "rb") as file: | ||||||
|  |         obj = pickle.load(file) | ||||||
|  |         return obj | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| # train_config = TrainConfig() | # train_config = TrainConfig() | ||||||
| # train_config_dict = class_to_dict(train_config) | # train_config_dict = class_to_dict(train_config) | ||||||
| # import pprint | # import pprint | ||||||
|  |  | ||||||
|  | @ -1,46 +0,0 @@ | ||||||
| import torch |  | ||||||
| 
 |  | ||||||
| from model.light_module import LightModule |  | ||||||
| from model.light_module import ModelRunner |  | ||||||
| import numpy as np |  | ||||||
| 
 |  | ||||||
| import meaning.dataset as ds |  | ||||||
| 
 |  | ||||||
| if __name__ == "__main__": |  | ||||||
| 
 |  | ||||||
|     # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" |  | ||||||
|     # checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt" |  | ||||||
|     checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt" |  | ||||||
| 
 |  | ||||||
|     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) |  | ||||||
|     qwen.eval() |  | ||||||
|     conf = qwen.config |  | ||||||
|     torch.manual_seed(conf.seed) |  | ||||||
|     np.random.seed(conf.seed) |  | ||||||
|     torch.cuda.manual_seed_all(conf.seed) |  | ||||||
| 
 |  | ||||||
|     runner = ModelRunner(qwen.llm) |  | ||||||
| 
 |  | ||||||
|     _, val = ds.InitDataset(conf).dataset |  | ||||||
|     md = val.meaning_dataset |  | ||||||
|     map = md.get_meaning_map() |  | ||||||
| 
 |  | ||||||
|     # seq:844 |  | ||||||
|     # seq:849 |  | ||||||
|     # seq:991 |  | ||||||
|     # seq:995 |  | ||||||
|     seq = 995 |  | ||||||
| 
 |  | ||||||
|     node = map.get_nodetree(seq) |  | ||||||
|     item, l, rank_idx, rank_all = map.get_sequence(seq) |  | ||||||
|     print("len of seq:" + str(len(item))) |  | ||||||
| 
 |  | ||||||
|     for i in range(1, len(item)): |  | ||||||
|         itemm = [item[:i]] |  | ||||||
|         batch = torch.tensor([item[:i]], dtype=torch.int64) |  | ||||||
|         sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) |  | ||||||
|         next_token = sorted_indices.detach().cpu().numpy()[0][0] |  | ||||||
|         if item[i] != next_token: |  | ||||||
|             node.set_seq_prop(i, "ERR_" + str(next_token)) |  | ||||||
|             print(str(item[i]) + "  " + str(next_token) + "  ERROR") |  | ||||||
|     node.print() |  | ||||||
|  | @ -2,28 +2,100 @@ import torch | ||||||
| 
 | 
 | ||||||
| from model.light_module import LightModule | from model.light_module import LightModule | ||||||
| from model.light_module import ModelRunner | from model.light_module import ModelRunner | ||||||
|  | from model.modeling_wit import QWenLMHeadModel | ||||||
|  | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| import math | import math | ||||||
| import sys | import sys | ||||||
|  | import os | ||||||
| 
 | 
 | ||||||
| sys.path.append("..") | sys.path.append("..") | ||||||
| from tools import show | from tools import show | ||||||
| 
 | import configuration | ||||||
| 
 | 
 | ||||||
| import meaning.dataset as ds | import meaning.dataset as ds | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | def get_latest_file_safe(directory): | ||||||
|  |     try: | ||||||
|  |         files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))] | ||||||
|  |         if not files: | ||||||
|  |             print("警告:目录中没有文件") | ||||||
|  |             return None | ||||||
|  |         latest = max(files, key=lambda f: os.path.getmtime(os.path.join(directory, f))) | ||||||
|  |         return latest | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"错误: {e}") | ||||||
|  |         return None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_dataset_set_freq(dataset): | ||||||
|  |     loader = dataset | ||||||
|  | 
 | ||||||
|  |     map = loader.meaning_dataset.get_meaning_map() | ||||||
|  |     seqs = {} | ||||||
|  |     for batch in loader: | ||||||
|  |         for m in batch["meaning"]: | ||||||
|  |             seqs[m] = map.get_sequence(m) | ||||||
|  |     while True: | ||||||
|  |         m = int(input("input meaning: ")) | ||||||
|  |         total = 0 | ||||||
|  |         for seq in seqs.values(): | ||||||
|  |             total = total + seq.count(m) | ||||||
|  |         print(f"meaning of {m} count as {total}") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_inference(dataset, seq): | ||||||
|  |     map = dataset.get_meaning_map() | ||||||
|  | 
 | ||||||
|  |     node = map.get_nodetree(seq) | ||||||
|  |     item, l, rank_idx, rank_all = map.get_sequence(seq) | ||||||
|  |     print("len of seq:" + str(len(item))) | ||||||
|  | 
 | ||||||
|  |     for i in range(1, len(item)): | ||||||
|  |         itemm = [item[:i]] | ||||||
|  |         batch = torch.tensor([item[:i]], dtype=torch.int64) | ||||||
|  |         sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) | ||||||
|  |         next_token = sorted_indices.detach().cpu().numpy()[0][0] | ||||||
|  |         if item[i] != next_token: | ||||||
|  |             node.set_seq_prop(i, "ERR_" + str(next_token)) | ||||||
|  |             print(str(item[i]) + "  " + str(next_token) + "  ERROR") | ||||||
|  |     node.print() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
| 
 | 
 | ||||||
|     checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt" |     log_path = "log/bigger/version_1/" | ||||||
| 
 | 
 | ||||||
|     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) |     file = get_latest_file_safe(log_path + "/checkpoints") | ||||||
|  |     checkpoint_path = log_path + "checkpoints/" + file | ||||||
|  |     conf = configuration.class_from_file(log_path + "conf.pkl") | ||||||
|  |     model = QWenLMHeadModel(conf.model_config) | ||||||
|  |     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model) | ||||||
|     qwen.eval() |     qwen.eval() | ||||||
|     conf = qwen.config |     conf = qwen.config | ||||||
|     torch.manual_seed(conf.seed) |     torch.manual_seed(conf.seed) | ||||||
|     np.random.seed(conf.seed) |     np.random.seed(conf.seed) | ||||||
|     runner = ModelRunner(qwen.llm) |     runner = ModelRunner(qwen.llm) | ||||||
| 
 | 
 | ||||||
|  |     train, val = ds.InitDataset(conf) | ||||||
|  |     val = val.dataset | ||||||
|  |     # get_dataset_set_freq(train.dataset) | ||||||
|  |     md = val.meaning_dataset | ||||||
|  |     map = md.get_meaning_map() | ||||||
|  | 
 | ||||||
|  |     # seq:844 | ||||||
|  |     # seq:849 | ||||||
|  |     # seq:991 | ||||||
|  |     # seq:995 | ||||||
|  |     meaning = 995 | ||||||
|  | 
 | ||||||
|  |     get_inference(md, meaning) | ||||||
|  | 
 | ||||||
|  |     node = map.get_nodetree(meaning) | ||||||
|  |     node.print() | ||||||
|  | 
 | ||||||
|     def DumpQK(query, key, causal_mask, index): |     def DumpQK(query, key, causal_mask, index): | ||||||
|         global relation_distance |         global relation_distance | ||||||
|         size = query.shape[2] |         size = query.shape[2] | ||||||
|  | @ -37,26 +109,13 @@ if __name__ == "__main__": | ||||||
|         qk = attn_weight[0] |         qk = attn_weight[0] | ||||||
|         prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" |         prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" | ||||||
|         qk = qk.cpu() |         qk = qk.cpu() | ||||||
|         qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0) |         # qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0) | ||||||
|         show.DumpTensorToImage(qk, prePath) |         show.DumpTensorToImage(qk, prePath) | ||||||
|         # qk_seq.append(qk) |         # qk_seq.append(qk) | ||||||
|         # qk_index = size |         # qk_index = size | ||||||
| 
 | 
 | ||||||
|     qwen.llm.hook_attention = DumpQK |     qwen.llm.hook_attention = DumpQK | ||||||
| 
 | 
 | ||||||
|     _, val = ds.InitDataset(conf).dataset |  | ||||||
|     md = val.meaning_dataset |  | ||||||
|     map = md.get_meaning_map() |  | ||||||
| 
 |  | ||||||
|     # seq:844 |  | ||||||
|     # seq:849 |  | ||||||
|     # seq:991 |  | ||||||
|     # seq:995 |  | ||||||
|     meaning = 995 |  | ||||||
| 
 |  | ||||||
|     node = map.get_nodetree(meaning) |  | ||||||
|     node.print() |  | ||||||
| 
 |  | ||||||
|     # current_to_common, common_to_current = map.get_level_change(meaning) |     # current_to_common, common_to_current = map.get_level_change(meaning) | ||||||
|     # print(current_to_common) |     # print(current_to_common) | ||||||
|     # print(common_to_current) |     # print(common_to_current) | ||||||
|  |  | ||||||
|  | @ -1,35 +0,0 @@ | ||||||
| import pytorch_lightning as pl |  | ||||||
| import torch |  | ||||||
| 
 |  | ||||||
| from model.light_module import LightModule |  | ||||||
| from model.tokenization_qwen import QWenTokenizer |  | ||||||
| import numpy as np |  | ||||||
| 
 |  | ||||||
| import configuration |  | ||||||
| import meaning as m |  | ||||||
| 
 |  | ||||||
| if __name__ == "__main__": |  | ||||||
| 
 |  | ||||||
|     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt" |  | ||||||
| 
 |  | ||||||
|     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) |  | ||||||
|     qwen.eval() |  | ||||||
|     conf = qwen.config |  | ||||||
|     torch.manual_seed(conf.seed) |  | ||||||
|     np.random.seed(conf.seed) |  | ||||||
| 
 |  | ||||||
|     train_dataloader, val_dataloader = m.InitDataset(conf) |  | ||||||
| 
 |  | ||||||
|     loader = train_dataloader.dataset |  | ||||||
| 
 |  | ||||||
|     map = loader.meaning_dataset.get_meaning_map() |  | ||||||
|     seqs = {} |  | ||||||
|     for batch in loader: |  | ||||||
|         for m in batch["meaning"]: |  | ||||||
|             seqs[m] = map.get_sequence(m) |  | ||||||
|     while True: |  | ||||||
|         m = int(input("input meaning: ")) |  | ||||||
|         total = 0 |  | ||||||
|         for seq in seqs.values(): |  | ||||||
|             total = total + seq.count(m) |  | ||||||
|         print(f"meaning of {m} count as {total}") |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 Colin
						Colin