Refine query_block_output.
This commit is contained in:
parent
ee30eb4aab
commit
3e6ff2d580
|
@ -1,3 +1,6 @@
|
|||
import pickle
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(self):
|
||||
self.vocab_size = 4096
|
||||
|
@ -90,6 +93,17 @@ def class_to_dict(obj):
|
|||
return str(obj)
|
||||
|
||||
|
||||
def class_to_file(obj, file):
    """Serialize ``obj`` with pickle and write it to the path ``file``.

    Args:
        obj: Any picklable object.
        file: Path of the output file. It is opened in binary write mode,
            so an existing file at that path is overwritten.
    """
    # Use a dedicated name for the handle so the ``file`` path argument is
    # not rebound inside the ``with`` block.
    with open(file, "wb") as handle:
        pickle.dump(obj, handle)
|
||||
|
||||
|
||||
def class_from_file(file):
    """Load and return the object pickled in the file at path ``file``.

    Args:
        file: Path of a file containing a pickled object; it is opened in
            binary read mode.

    Returns:
        The object reconstructed by ``pickle.load``.
    """
    # NOTE(security): unpickling can execute arbitrary code, so this must only
    # be used on files that come from a trusted source.
    with open(file, "rb") as handle:
        return pickle.load(handle)
|
||||
|
||||
|
||||
# train_config = TrainConfig()
|
||||
# train_config_dict = class_to_dict(train_config)
|
||||
# import pprint
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
import torch
|
||||
|
||||
from model.light_module import LightModule
|
||||
from model.light_module import ModelRunner
|
||||
import numpy as np
|
||||
|
||||
import meaning.dataset as ds
|
||||
|
||||
if __name__ == "__main__":
    # Script: load a checkpoint, replay one meaning sequence token by token
    # through the model, and mark every position where the model's first
    # ranked token differs from the actual next token.

    # Other checkpoints that were tried:
    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
    # checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt"
    checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt"

    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config

    # Seed every RNG in use from the configuration stored with the checkpoint.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    torch.cuda.manual_seed_all(conf.seed)

    runner = ModelRunner(qwen.llm)

    _, val = ds.InitDataset(conf).dataset
    md = val.meaning_dataset
    # Named ``meaning_map`` so the builtin ``map`` is not shadowed.
    meaning_map = md.get_meaning_map()

    # Sequences of interest:
    # seq:844
    # seq:849
    # seq:991
    # seq:995
    seq = 995

    node = meaning_map.get_nodetree(seq)
    # get_sequence() returns four values; only the first (the tokens) is used.
    item, _, _, _ = meaning_map.get_sequence(seq)
    print("len of seq:" + str(len(item)))

    for i in range(1, len(item)):
        # Feed the prefix item[:i] as a batch of one and compare the first
        # ranked index with the real token at position i.
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(str(item[i]) + " " + str(next_token) + " ERROR")

    # NOTE(review): the scraped source lost its indentation; printing the
    # tree once after the loop is the assumed original structure - confirm.
    node.print()
|
|
@ -2,28 +2,100 @@ import torch
|
|||
|
||||
from model.light_module import LightModule
|
||||
from model.light_module import ModelRunner
|
||||
from model.modeling_wit import QWenLMHeadModel
|
||||
|
||||
import numpy as np
|
||||
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append("..")
|
||||
from tools import show
|
||||
|
||||
import configuration
|
||||
|
||||
import meaning.dataset as ds
|
||||
|
||||
|
||||
def get_latest_file_safe(directory):
    """Return the name of the most recently modified file in ``directory``.

    Only regular files directly inside ``directory`` are considered
    (sub-directories are ignored). Filesystem errors are reported on stdout
    instead of being raised.

    Args:
        directory: Path of the directory to inspect.

    Returns:
        The bare file name (not the full path) with the largest modification
        time, or ``None`` when the directory holds no files or a filesystem
        error occurs (missing directory, permission problem, a file that
        disappears while being inspected, ...).
    """
    try:
        # Map each regular file's name to its full path so the path is
        # joined only once per entry.
        paths = {
            name: os.path.join(directory, name)
            for name in os.listdir(directory)
        }
        files = [name for name, path in paths.items() if os.path.isfile(path)]
        if not files:
            print("警告:目录中没有文件")
            return None
        return max(files, key=lambda name: os.path.getmtime(paths[name]))
    except OSError as e:
        # Only filesystem failures are part of the "safe" contract; other
        # exceptions indicate programming errors and are allowed to surface.
        print(f"错误: {e}")
        return None
|
||||
|
||||
|
||||
def get_dataset_set_freq(dataset):
    """Interactively report how often a meaning id occurs in the dataset.

    Every batch of ``dataset`` is scanned once and the sequence of each id in
    ``batch["meaning"]`` is cached. The function then loops forever: it reads
    an integer from stdin and prints how many times that value occurs across
    all cached sequences.

    Args:
        dataset: Iterable of batches (mappings with a ``"meaning"`` entry)
            that also exposes ``meaning_dataset.get_meaning_map()``.

    Raises:
        ValueError: If the text typed at the prompt is not an integer.
    """
    # Named ``meaning_map`` so the builtin ``map`` is not shadowed.
    meaning_map = dataset.meaning_dataset.get_meaning_map()

    seqs = {}
    for batch in dataset:
        for m in batch["meaning"]:
            # get_sequence() is unpacked into four values elsewhere in this
            # file, with the token sequence first. Keep only the tokens;
            # counting over the whole returned tuple would never look inside
            # the sequence. NOTE(review): confirm the return layout.
            seqs[m] = meaning_map.get_sequence(m)[0]

    while True:
        m = int(input("input meaning: "))
        # Element-wise comparison works for any iterable token container,
        # not only for lists that provide ``.count``.
        total = sum(
            1 for tokens in seqs.values() for token in tokens if token == m
        )
        print(f"meaning of {m} count as {total}")
|
||||
|
||||
|
||||
def get_inference(dataset, seq, model_runner=None):
    """Replay one sequence through the model and flag mispredicted tokens.

    For every prefix ``item[:i]`` of the token sequence, the runner is asked
    for its token ranking (``sample=False``). Whenever the first entry of the
    returned ``sorted_indices`` differs from the actual token ``item[i]``,
    position ``i`` is tagged ``"ERR_<predicted>"`` on the node tree and a line
    is printed. The annotated tree is printed once at the end.

    Args:
        dataset: Object exposing ``get_meaning_map()``.
        seq: Key passed to the map's ``get_nodetree`` and ``get_sequence``.
        model_runner: Object exposing ``ChatTokens(batch, sample=False)``
            returning ``(sorted_logits, sorted_indices)``. When omitted, the
            module-level ``runner`` created by the ``__main__`` script is
            used, so existing two-argument calls keep working.
    """
    if model_runner is None:
        # Fall back to the global set up by the script entry point.
        model_runner = runner

    # Named ``meaning_map`` so the builtin ``map`` is not shadowed.
    meaning_map = dataset.get_meaning_map()

    node = meaning_map.get_nodetree(seq)
    # get_sequence() returns four values; only the first (the tokens) is used.
    item, _, _, _ = meaning_map.get_sequence(seq)
    print("len of seq:" + str(len(item)))

    for i in range(1, len(item)):
        # Batch of one containing the prefix that precedes position i.
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = model_runner.ChatTokens(batch, sample=False)
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(str(item[i]) + " " + str(next_token) + " ERROR")

    # NOTE(review): the scraped source lost its indentation; printing the
    # tree once after the loop is the assumed original structure - confirm.
    node.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt"
|
||||
log_path = "log/bigger/version_1/"
|
||||
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||
checkpoint_path = log_path + "checkpoints/" + file
|
||||
conf = configuration.class_from_file(log_path + "conf.pkl")
|
||||
model = QWenLMHeadModel(conf.model_config)
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
|
||||
qwen.eval()
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
train, val = ds.InitDataset(conf)
|
||||
val = val.dataset
|
||||
# get_dataset_set_freq(train.dataset)
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
|
||||
get_inference(md, meaning)
|
||||
|
||||
node = map.get_nodetree(meaning)
|
||||
node.print()
|
||||
|
||||
def DumpQK(query, key, causal_mask, index):
|
||||
global relation_distance
|
||||
size = query.shape[2]
|
||||
|
@ -37,26 +109,13 @@ if __name__ == "__main__":
|
|||
qk = attn_weight[0]
|
||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||
qk = qk.cpu()
|
||||
qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
|
||||
# qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
|
||||
show.DumpTensorToImage(qk, prePath)
|
||||
# qk_seq.append(qk)
|
||||
# qk_index = size
|
||||
|
||||
qwen.llm.hook_attention = DumpQK
|
||||
|
||||
_, val = ds.InitDataset(conf).dataset
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
|
||||
node = map.get_nodetree(meaning)
|
||||
node.print()
|
||||
|
||||
# current_to_common, common_to_current = map.get_level_change(meaning)
|
||||
# print(current_to_common)
|
||||
# print(common_to_current)
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from model.light_module import LightModule
|
||||
from model.tokenization_qwen import QWenTokenizer
|
||||
import numpy as np
|
||||
|
||||
import configuration
|
||||
import meaning as m
|
||||
|
||||
if __name__ == "__main__":
    # Script: load a checkpoint, cache the sequence of every meaning id in
    # the training dataset, then answer interactive frequency queries.

    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"

    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config

    # Seed the RNGs from the configuration stored with the checkpoint.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    train_dataloader, val_dataloader = m.InitDataset(conf)

    loader = train_dataloader.dataset

    # Named ``meaning_map`` so the builtin ``map`` is not shadowed.
    meaning_map = loader.meaning_dataset.get_meaning_map()
    seqs = {}
    for batch in loader:
        # ``meaning_id`` rather than ``m``: reusing ``m`` would overwrite the
        # ``meaning`` module alias imported at the top of the file.
        for meaning_id in batch["meaning"]:
            # get_sequence() is unpacked into four values by the inference
            # scripts of this project, with the token sequence first. Keep
            # only the tokens; counting over the whole returned tuple would
            # never look inside the sequence.
            # NOTE(review): confirm the return layout.
            seqs[meaning_id] = meaning_map.get_sequence(meaning_id)[0]

    # Runs until interrupted; a non-integer answer raises ValueError.
    while True:
        query = int(input("input meaning: "))
        # Element-wise comparison works for any iterable token container,
        # not only for lists that provide ``.count``.
        total = sum(
            1 for tokens in seqs.values() for token in tokens if token == query
        )
        print(f"meaning of {query} count as {total}")
|
Loading…
Reference in New Issue