Witllm/qwen/research_attention.py

import torch
import sys
from modelscope import snapshot_download
from transformers import AutoTokenizer
from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
from modeling_qwen import QwenRunner
sys.path.append("..")  # make the sibling tools package importable
from tools import show
from tools import mem_tracker
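# Fix the random seeds so generation is reproducible across runs.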
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
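# Load the model config from the local directory; "./" is assumed to hold the
# Qwen config.json that ships with this research repo.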
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
model = QWenLMHeadModel(config)
print(model)  # inspect the module tree before the weights are loaded
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# from_pretrained is a classmethod, so this rebinds model to a freshly loaded
# instance rather than filling in the one constructed above.
model = model.from_pretrained(model_dir).cuda()
model = model.eval()
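

# ResearchRunner overrides forwardAttention so that every layer's raw q@k
# score matrix is dumped to an image before the standard attention path runs.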
class ResearchRunner(QwenRunner):
    def forwardAttention(
        self,
        attention,
        hidden_states,
        rotary_pos_emb_list=None,
    ):
        query, key, value = self.split_heads(attention, hidden_states)
        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
        causal_mask = self.build_mask(query)
        # Move the head axis forward: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        q = query.permute(0, 2, 1, 3)
        k = key.permute(0, 2, 1, 3)
        size = query.shape[1]  # sequence length
        qk = q @ k.transpose(-2, -1)  # pre-softmax attention scores per head
        qk = qk[0]  # keep the first (only) batch element
        prePath = "./temp/"
        show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
        return self.attention(attention, query, key, value, causal_mask)

runner = ResearchRunner(model)
# First round of conversation. The prompt mislabels Japan as a Southeast Asian
# country: "What city is the capital of the Southeast Asian country Japan?"
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
print(decode_tokens)
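# Expected transcript (the assistant answers "日本的首都东京。",
# i.e. "Japan's capital, Tokyo."):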
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# 东南亚国家日本的首都是什么市<|im_end|>
# <|im_start|>assistant
# 日本的首都东京。<|im_end|>
# <|endoftext|>
# Sanity check: fail loudly if the model's answer drifts from the expected one.
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
    raise ValueError("unexpected model response: " + decode_tokens)
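# The q@k heatmaps land under ./temp/ (one image per layer); create the
# directory first if show.DumpTensorToImage does not create it itself.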