Witllm/wit/query_block_output.py

83 lines
2.6 KiB
Python
Raw Normal View History

2025-03-03 15:23:39 +08:00
import torch
2025-03-18 15:58:08 +08:00
from model.light_module import LightModule
from model.light_module import ModelRunner
2025-03-03 15:23:39 +08:00
import numpy as np
import math
import sys
sys.path.append("..")
from tools import show
import dataset.dataset as ds
if __name__ == "__main__":
    # Checkpoint file the model under inspection is restored from.
    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"

    # Rebuild the LightModule from the checkpoint and put it in eval mode.
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()

    # Seed both torch and numpy with the seed stored in the model's config.
    conf = qwen.config
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(conf.seed)

    # Runner wrapping the underlying language model; used later for inference.
    runner = ModelRunner(qwen.llm)
def DumpQK(query, key, causal_mask, index):
    """Dump the masked attention probabilities of one layer to a PNG image.

    Computes ``softmax(query @ key^T / sqrt(d))`` with the positions that
    ``causal_mask`` disallows excluded from the softmax, takes the first
    entry along dim 0, appends the module-level ``relation_table`` as one
    extra slice, and writes the result to
    ``./temp/q@k_seq_<size>_layer_<index>.png``.

    Args:
        query: Query tensor. ``query.shape[2]`` is used as the sequence
            length in the file name and ``query.size(-1)`` as the scaling
            dimension (assumes a (batch, heads, seq, head_dim) layout --
            TODO confirm against the caller).
        key: Key tensor; its last two dims are transposed and
            matrix-multiplied with ``query``.
        causal_mask: Mask broadcastable to the score matrix. Truthy entries
            mark positions that may be attended to.
        index: Identifier placed in the output file name (presumably the
            layer index -- confirm against the caller).
    """
    size = query.shape[2]
    scale_factor = 1 / math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) * scale_factor

    # Disallowed positions must be -inf *before* the softmax. Multiplying the
    # scores by a 0/1 mask leaves a logit of 0 there, which still receives
    # exp(0) worth of probability mass and distorts the visible entries.
    blocked = causal_mask.logical_not()
    scores = scores.masked_fill(blocked, float("-inf"))
    attn_weight = torch.softmax(scores, dim=-1)
    # Force blocked entries to exactly 0; this also clears the NaN values a
    # fully blocked row produces.
    attn_weight = attn_weight.masked_fill(blocked, 0.0)

    # Only the first entry along dim 0 is visualised.
    qk = attn_weight[0].detach().cpu()

    # `relation_table` is read from module scope (no `global` statement is
    # needed for a read). It is appended as one extra slice, so it must match
    # the shape of a single slice of `qk`.
    qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)

    prePath = f"./temp/q@k_seq_{size}_layer_{index}.png"
    show.DumpTensorToImage(qk, prePath, GridValue=255)
# Guarded so that importing this module never triggers dataset loading or
# inference; the driver below only runs when the file is executed as a script.
if __name__ == "__main__":
    # Install DumpQK as the model's attention hook (presumably invoked for
    # each attention layer during the forward pass -- confirm in the model).
    qwen.llm.hook_attention = DumpQK

    val = ds.InitValDataset(conf).dataset
    md = val.meaning_dataset
    # Named `meaning_map` rather than `map` so the builtin map() stays usable.
    meaning_map = md.get_meaning_map()

    # seq:844
    # seq:849
    # seq:991
    # seq:995
    meaning = 995

    # Print the node tree and the level-change tables of the chosen meaning.
    node = meaning_map.get_nodetree(meaning)
    current_to_common, common_to_current = meaning_map.get_level_change(meaning)
    node.print()
    print(current_to_common)
    print(common_to_current)

    # DumpQK reads the module-level name `relation_table`, so it has to be a
    # tensor before inference runs below.
    relation_table = meaning_map.get_relation_table(meaning)
    # prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
    # show.DumpTensorToImage(relation_table, prePath, GridValue=255)
    relation_table = torch.tensor(relation_table)

    # Token sequence for the meaning, plus its per-token metadata.
    item, level, rank_idx, rank_all = meaning_map.get_sequence(meaning)
    for sequence_field in (item, level, rank_idx, rank_all):
        print(sequence_field)
    print("len of seq:" + str(len(item)))

    # Single-sequence batch; running it fires the attention hook set above.
    batch = torch.tensor([item], dtype=torch.int64)
    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    # First entry of the first row of the returned indices (presumably the
    # top-ranked next token); not used further in this script.
    next_token = sorted_indices.detach().cpu().numpy()[0][0]

    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    # print(sorted_logits.detach().cpu().numpy())
    # print(sorted_indices.detach().cpu().numpy())