Refine query_block_output.

This commit is contained in:
Colin 2025-08-17 17:58:23 +08:00
parent ee30eb4aab
commit 3e6ff2d580
4 changed files with 90 additions and 98 deletions

View File

@ -1,3 +1,6 @@
import pickle
class ModelConfig: class ModelConfig:
def __init__(self): def __init__(self):
self.vocab_size = 4096 self.vocab_size = 4096
@ -90,6 +93,17 @@ def class_to_dict(obj):
return str(obj) return str(obj)
def class_to_file(obj, file):
with open(file, "wb") as file:
pickle.dump(obj, file)
def class_from_file(file):
with open(file, "rb") as file:
obj = pickle.load(file)
return obj
# train_config = TrainConfig() # train_config = TrainConfig()
# train_config_dict = class_to_dict(train_config) # train_config_dict = class_to_dict(train_config)
# import pprint # import pprint

View File

@ -1,46 +0,0 @@
import torch
from model.light_module import LightModule
from model.light_module import ModelRunner
import numpy as np
import meaning.dataset as ds
if __name__ == "__main__":
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
# checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt"
checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt"
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
conf = qwen.config
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
torch.cuda.manual_seed_all(conf.seed)
runner = ModelRunner(qwen.llm)
_, val = ds.InitDataset(conf).dataset
md = val.meaning_dataset
map = md.get_meaning_map()
# seq:844
# seq:849
# seq:991
# seq:995
seq = 995
node = map.get_nodetree(seq)
item, l, rank_idx, rank_all = map.get_sequence(seq)
print("len of seq:" + str(len(item)))
for i in range(1, len(item)):
itemm = [item[:i]]
batch = torch.tensor([item[:i]], dtype=torch.int64)
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
next_token = sorted_indices.detach().cpu().numpy()[0][0]
if item[i] != next_token:
node.set_seq_prop(i, "ERR_" + str(next_token))
print(str(item[i]) + " " + str(next_token) + " ERROR")
node.print()

View File

@ -2,28 +2,100 @@ import torch
from model.light_module import LightModule from model.light_module import LightModule
from model.light_module import ModelRunner from model.light_module import ModelRunner
from model.modeling_wit import QWenLMHeadModel
import numpy as np import numpy as np
import math import math
import sys import sys
import os
sys.path.append("..") sys.path.append("..")
from tools import show from tools import show
import configuration
import meaning.dataset as ds import meaning.dataset as ds
def get_latest_file_safe(directory):
try:
files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
if not files:
print("警告:目录中没有文件")
return None
latest = max(files, key=lambda f: os.path.getmtime(os.path.join(directory, f)))
return latest
except Exception as e:
print(f"错误: {e}")
return None
def get_dataset_set_freq(dataset):
loader = dataset
map = loader.meaning_dataset.get_meaning_map()
seqs = {}
for batch in loader:
for m in batch["meaning"]:
seqs[m] = map.get_sequence(m)
while True:
m = int(input("input meaning: "))
total = 0
for seq in seqs.values():
total = total + seq.count(m)
print(f"meaning of {m} count as {total}")
def get_inference(dataset, seq):
map = dataset.get_meaning_map()
node = map.get_nodetree(seq)
item, l, rank_idx, rank_all = map.get_sequence(seq)
print("len of seq:" + str(len(item)))
for i in range(1, len(item)):
itemm = [item[:i]]
batch = torch.tensor([item[:i]], dtype=torch.int64)
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
next_token = sorted_indices.detach().cpu().numpy()[0][0]
if item[i] != next_token:
node.set_seq_prop(i, "ERR_" + str(next_token))
print(str(item[i]) + " " + str(next_token) + " ERROR")
node.print()
if __name__ == "__main__": if __name__ == "__main__":
checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt" log_path = "log/bigger/version_1/"
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) file = get_latest_file_safe(log_path + "/checkpoints")
checkpoint_path = log_path + "checkpoints/" + file
conf = configuration.class_from_file(log_path + "conf.pkl")
model = QWenLMHeadModel(conf.model_config)
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
qwen.eval() qwen.eval()
conf = qwen.config conf = qwen.config
torch.manual_seed(conf.seed) torch.manual_seed(conf.seed)
np.random.seed(conf.seed) np.random.seed(conf.seed)
runner = ModelRunner(qwen.llm) runner = ModelRunner(qwen.llm)
train, val = ds.InitDataset(conf)
val = val.dataset
# get_dataset_set_freq(train.dataset)
md = val.meaning_dataset
map = md.get_meaning_map()
# seq:844
# seq:849
# seq:991
# seq:995
meaning = 995
get_inference(md, meaning)
node = map.get_nodetree(meaning)
node.print()
def DumpQK(query, key, causal_mask, index): def DumpQK(query, key, causal_mask, index):
global relation_distance global relation_distance
size = query.shape[2] size = query.shape[2]
@ -37,26 +109,13 @@ if __name__ == "__main__":
qk = attn_weight[0] qk = attn_weight[0]
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
qk = qk.cpu() qk = qk.cpu()
qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0) # qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
show.DumpTensorToImage(qk, prePath) show.DumpTensorToImage(qk, prePath)
# qk_seq.append(qk) # qk_seq.append(qk)
# qk_index = size # qk_index = size
qwen.llm.hook_attention = DumpQK qwen.llm.hook_attention = DumpQK
_, val = ds.InitDataset(conf).dataset
md = val.meaning_dataset
map = md.get_meaning_map()
# seq:844
# seq:849
# seq:991
# seq:995
meaning = 995
node = map.get_nodetree(meaning)
node.print()
# current_to_common, common_to_current = map.get_level_change(meaning) # current_to_common, common_to_current = map.get_level_change(meaning)
# print(current_to_common) # print(current_to_common)
# print(common_to_current) # print(common_to_current)

View File

@ -1,35 +0,0 @@
import pytorch_lightning as pl
import torch
from model.light_module import LightModule
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import meaning as m
if __name__ == "__main__":
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
conf = qwen.config
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
train_dataloader, val_dataloader = m.InitDataset(conf)
loader = train_dataloader.dataset
map = loader.meaning_dataset.get_meaning_map()
seqs = {}
for batch in loader:
for m in batch["meaning"]:
seqs[m] = map.get_sequence(m)
while True:
m = int(input("input meaning: "))
total = 0
for seq in seqs.values():
total = total + seq.count(m)
print(f"meaning of {m} count as {total}")