Witllm/qwen/research_attention.py

77 lines
2.5 KiB
Python

import torch
import sys
import math
from modelscope import snapshot_download
from transformers import AutoTokenizer
from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
from modeling_qwen import QwenRunner
import torch.nn.functional as F
sys.path.append("..")
from tools import show
# Seed the CPU generator and every CUDA generator so repeated runs are comparable.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Resolve the checkpoint through ModelScope; the result is later handed to
# both the tokenizer and the model loader.
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# Alternative: point at an already-downloaded local copy instead.
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# Read the model configuration from the current directory. Because
# return_unused_kwargs=True the call yields a (config, leftover kwargs) pair.
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

# Build the model from the config and print its module structure.
model = QWenLMHeadModel(config)
print(model)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

# Load the checkpoint, move the model to the GPU and switch to inference mode.
model = model.from_pretrained(model_dir).cuda().eval()
class ResearchRunner(QwenRunner):
    """QwenRunner variant that writes an image of the attention weights on every call."""

    def attention(self, attention, query, key, value, causal_mask):
        """Dump the masked attention weights to ./temp/ and return the attention output.

        Args:
            attention: attention module; this method reads its ``index``,
                ``num_heads``, ``head_dim``, ``_merge_heads`` and ``c_proj``.
            query, key, value: 4-D tensors; dims 1 and 2 are swapped before use
                (presumably (batch, seq, heads, head_dim) on entry -- TODO confirm).
            causal_mask: mask passed to scaled_dot_product_attention as
                ``attn_mask``; positions where it is false/zero are zeroed in
                the dumped image.

        Returns:
            The result of ``attention.c_proj`` applied to the merged heads.
        """
        # Swap dims 1 and 2 of all three inputs.
        query, key, value = (t.permute(0, 2, 1, 3) for t in (query, key, value))
        seq_len = query.shape[2]

        # --- visualisation path -------------------------------------------
        # Scaled dot-product scores, normalised with softmax over the last dim.
        # NOTE(review): the softmax is taken *before* the causal mask is
        # applied, so the masked rows below are not renormalised -- confirm
        # this is the intended picture.
        scale_factor = 1 / math.sqrt(query.size(-1))
        probs = (query @ key.transpose(-2, -1) * scale_factor).softmax(dim=-1)

        # 1 where causal_mask is set, 0 elsewhere, in the dtype/device of query.
        keep = torch.ones(
            causal_mask.shape, dtype=query.dtype, device=query.device
        ).masked_fill(causal_mask.logical_not(), 0.0)

        # Only the first entry along dim 0 is written out.
        first_sample = (probs * keep)[0]
        image_path = f"./temp/q@k_seq_{seq_len}_layer_{attention.index!s}.png"
        show.DumpTensorToImage(first_sample, image_path)

        # --- actual forward computation -----------------------------------
        heads_out = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
        merged = attention._merge_heads(heads_out.transpose(1, 2), attention.num_heads, attention.head_dim)
        return attention.c_proj(merged)
runner = ResearchRunner(model)

# First round of dialogue.
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
print(decode_tokens)
# Sample decoded transcript (kept verbatim for reference):
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# 东南亚国家日本的首都是什么市<|im_end|>
# <|im_start|>assistant
# 日本的首都东京。<|im_end|>
# <|endoftext|>
# 日本的首都是东京。<|im_end|>

# Regression check: the second-to-last line of the decoded text must match the
# known-good answer. A transcript with fewer than two lines counts as a
# mismatch instead of surfacing as an IndexError.
expected_line = """日本的首都东京。<|im_end|>"""
decoded_lines = decode_tokens.split("\n")
actual_line = decoded_lines[-2] if len(decoded_lines) >= 2 else None
if actual_line != expected_line:
    # `raise ()` is not a valid raise target and only produced an unrelated
    # TypeError; raise a real exception that reports what went wrong.
    raise AssertionError(
        f"unexpected model output: expected {expected_line!r}, got {actual_line!r}"
    )