import torch

from model.light_module import LightModule
from model.light_module import ModelRunner
from model.modeling_wit import QWenLMHeadModel

import numpy as np
import math
import sys
import os

sys.path.append("..")
from tools import show
import configuration
import meaning.dataset as ds


def get_latest_file_safe(directory):
    # Return the name of the most recently modified file in directory, or None.
    try:
        files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
        if not files:
            print("Warning: no files in the directory")
            return None
        latest = max(files, key=lambda f: os.path.getmtime(os.path.join(directory, f)))
        return latest
    except Exception as e:
        print(f"Error: {e}")
        return None
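

# A minimal usage sketch, mirroring the __main__ block below:
#   name = get_latest_file_safe("log/bigger/version_1/checkpoints")
#   if name is not None:
#       print("latest checkpoint:", name)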


def get_dataset_set_freq(dataset):
    # Interactively count how often a given meaning id occurs in the token
    # sequences behind every meaning seen in the dataset.
    map = dataset.meaning_dataset.get_meaning_map()
    seqs = {}
    for batch in dataset:
        for m in batch["meaning"]:
            # get_sequence returns (tokens, level, rank_idx, rank_all), as in
            # get_inference below; keep only the token list so count() works.
            seqs[m], _, _, _ = map.get_sequence(m)
    while True:  # loop until interrupted
        m = int(input("input meaning: "))
        total = 0
        for seq in seqs.values():
            total = total + seq.count(m)
        print(f"meaning of {m} count as {total}")
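

# Usage sketch (matches the call commented out in __main__ below):
#   train, val = ds.InitDataset(conf)
#   get_dataset_set_freq(train.dataset)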


def get_inference(dataset, seq):
    # Greedy next-token check: for every prefix of the sequence behind the
    # meaning id `seq`, ask the model for its top-1 prediction and mark the
    # positions where it disagrees with the golden token. Relies on the
    # module-level `runner` created in the __main__ block.
    map = dataset.get_meaning_map()

    node = map.get_nodetree(seq)
    item, l, rank_idx, rank_all = map.get_sequence(seq)
    print("len of seq:" + str(len(item)))

    for i in range(1, len(item)):
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print("index: " + str(i) + " golden: " + str(item[i]) + " -> " + str(next_token) + " ERR")
    node.print()
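

# Example call (mirrors the __main__ flow below):
#   get_inference(md, 991)  # walks every prefix of meaning 991's sequence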


if __name__ == "__main__":

    log_path = "log/bigger/version_1/"

    # Load the newest checkpoint under log_path and rebuild the model from
    # the pickled configuration saved next to it (assumes at least one
    # checkpoint file exists).
    file = get_latest_file_safe(log_path + "checkpoints")
    checkpoint_path = log_path + "checkpoints/" + file
    conf = configuration.class_from_file(log_path + "conf.pkl")
    model = QWenLMHeadModel(conf.model_config)
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
    qwen.eval()
    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    runner = ModelRunner(qwen.llm)

    train, val = ds.InitDataset(conf)
    val = val.dataset
    # get_dataset_set_freq(train.dataset)
    md = val.meaning_dataset
    map = md.get_meaning_map()

    # Candidate meanings to inspect:
    # seq:844
    # seq:849
    # seq:991
    # seq:995
    # meaning = 991
    meaning = 10991

    node = map.get_nodetree(meaning)
    node.print()

    get_inference(md, meaning)

    def DumpQK(query, key, causal_mask, index):
        # Attention hook: rebuild the scaled dot-product attention map for
        # this layer and dump it as an image. `relation_distance` is kept
        # global for the optional concat below.
        global relation_distance
        size = query.shape[2]
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        # Masked positions are zeroed (not set to -inf) both before and after
        # the softmax, so rows are not renormalized; this is only for plotting.
        attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
        attn_weight = attn_weight * attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight * attn_mask
        qk = attn_weight[0]
        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
        qk = qk.cpu()
        # qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
        show.DumpTensorToImage(qk, prePath)
        # qk_seq.append(qk)
        # qk_index = size

    qwen.llm.hook_attention = DumpQK
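
    # Shape sketch (the concrete numbers are hypothetical): with query/key of
    # shape [batch, heads, seq, dim] = [1, 6, 16, 64], attn_weight comes out
    # [1, 6, 16, 16], and each dump renders one batch of per-head seq x seq maps.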

    # current_to_common, common_to_current = map.get_level_change(meaning)
    # print(current_to_common)
    # print(common_to_current)

    # Dump the relation-distance vector of the meaning, tiled to six rows so
    # the one-row vector is visible in the saved image.
    relation_distance = map.get_relation_distance(meaning)
    relation_distance = torch.tensor(relation_distance)
    rd = relation_distance.unsqueeze(0)
    rd = torch.cat((rd, rd, rd, rd, rd, rd), dim=0)
    show.DumpTensorToImage(rd, "./temp/q@k_seq_layer.png")

    # Run the full sequence through the model once; the attention hook above
    # fires during this forward pass and writes the per-layer images.
    item, level, rank_idx, rank_all = map.get_sequence(meaning)
    print("len of seq:" + str(len(item)))

    batch = torch.tensor([item], dtype=torch.int64)
    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    next_token = sorted_indices.detach().cpu().numpy()[0][0]
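
    # Report the greedy (top-1) prediction for the token after the sequence.
    print("predicted next token: " + str(next_token))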

    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)

    # print(sorted_logits.detach().cpu().numpy())
    # print(sorted_indices.detach().cpu().numpy())