Witllm/qwen/research_attention.py

154 lines
4.7 KiB
Python

import torch
import sys
import math
from modelscope import snapshot_download
from transformers import AutoTokenizer
from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
from modeling_qwen import QwenRunner
import torch.nn.functional as F
from qwen_generation_utils import (
make_context,
decode_tokens,
)
sys.path.append("..")
from tools import show
# Fix the RNG seeds (CPU and every CUDA device) so repeated runs are comparable.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Fetch the Qwen-1.8B-Chat checkpoint from ModelScope; snapshot_download returns
# the local directory that holds it.
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# Alternative: point directly at an already-downloaded local cache.
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
# The model config is read from the current working directory ("./"), not from
# model_dir. With return_unused_kwargs=True the call returns a
# (config, unused_kwargs) pair.
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
# Build a model from the local config only to print its module structure;
# this instance is replaced a few lines below.
model = QWenLMHeadModel(config)
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# NOTE(review): from_pretrained is presumably the inherited transformers
# classmethod. If so, this builds a *new* model from model_dir (with model_dir's
# own config unless `config=` is passed) and discards the instance above. Both
# models are alive while loading, roughly doubling peak memory. Consider loading
# directly, or `del model` first. TODO confirm.
model = model.from_pretrained(model_dir)
# Inference mode: disables dropout and other training-only behaviour.
model = model.eval()
def Dump_tokens_list(model):
    """Write each token id below ``config.eos_token_id`` with its decoded text.

    Every id in ``range(config.eos_token_id)`` is decoded on its own with
    ``decode_tokens``. These ids are presumably the ordinary, non-special
    vocabulary; TODO confirm. The lines are written to
    ``./temp/qwen_token_list.txt`` via ``show.DumpListToFile``, one per id,
    formatted as ``"<id padded to 7 digits>: <repr of text>"``.

    ``model`` is accepted for symmetry with the other dump helpers but is not used.
    """

    def describe(token_id):
        # decode_tokens yields (decoded, response, end_reason); only the text is kept.
        text, _response, _end_reason = decode_tokens(
            [token_id],
            tokenizer,
            raw_text_len=0,
            context_length=0,
            errors="replace",
        )
        return f"{token_id:07d}: {text!r}"

    lines = [describe(token_id) for token_id in range(config.eos_token_id)]
    show.DumpListToFile(lines, "./temp/qwen_token_list.txt")
# Dump_tokens_list(model)
def Dump_lm_head_weight(model):
weight = model.lm_head.weight.cpu() # [151936,2048,]
weight = weight.reshape(64, -1, 2048)
for i in range(64):
sub = weight[i].reshape(-1, 64, 32)
show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png")
# Dump_lm_head_weight(model)
def DumpQK(query, key, causal_mask, index):
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
attn_weight = attn_weight * attn_mask
attn_weight = torch.softmax(attn_weight, dim=-1)
size = query.shape[2]
qk = attn_weight[0]
# prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
# show.DumpTensorToImage(qk, prePath, GridValue=255)
prePath = "./temp/" + "q@k_sum_seq_" + str(size) + "_layer_" + str(index) + ".png"
show.DumpTensorToImage(qk.sum(0), prePath, GridValue=255)
class ResearchRunner(QwenRunner):
    """QwenRunner that dumps per-layer attention maps and uses a selectable prompt.

    ``attention`` writes an attention-probability image through ``DumpQK`` on
    every call, then runs the real attention. ``prepareInput`` builds the prompt
    token ids according to ``prompt_layout``.
    """

    # Prompt layout built by prepareInput:
    #   "probe"     - fixed text "user\n你好\nassistant\n" after the start marker.
    #                 query / query_assistant are ignored. This is the default.
    #   "chat"      - fixed system turn + user turn + open assistant turn.
    #   "no_system" - user turn + open assistant turn.
    prompt_layout = "probe"

    def __init__(self, model):
        super().__init__(model)

    def attention(self, attention, query, key, value, causal_mask):
        """Dump this layer's attention map, then run standard SDPA attention.

        ``attention`` is the layer module; its ``index``, ``_merge_heads``,
        ``num_heads``, ``head_dim`` and ``c_proj`` are used. query/key/value
        have their dims 1 and 2 swapped before SDPA. They presumably arrive as
        ``(batch, seq, heads, head_dim)``; TODO confirm against QwenRunner.
        """
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)
        DumpQK(query, key, causal_mask, attention.index)
        # Swap dims 1 and 2 back before merging the heads.
        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=causal_mask
        ).transpose(1, 2)
        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
        return attention.c_proj(context_layer)

    def prepareInput(self, tokenizer, query, query_assistant, history, system):
        """Return ``("", token_ids)`` for the layout selected by ``prompt_layout``.

        ``history`` and ``system`` are never used. The "chat" layout always uses
        the fixed "You are a helpful assistant." system message. With the
        default "probe" layout, ``query`` and ``query_assistant`` are ignored too.

        Raises:
            ValueError: if ``prompt_layout`` is not a known layout.
        """
        # Turn-framing ids: <|im_start|>, "\n", <|im_end|> in Qwen's vocabulary.
        im_start = [151644]
        newline = [198]
        im_end = [151645]

        def encode(text):
            return tokenizer.encode(text, allowed_special=set())

        layout = self.prompt_layout
        if layout == "probe":
            return "", im_start + encode("user\n你好\nassistant\n")

        user_turn = im_start + encode("user\n" + query) + im_end + newline
        # The assistant turn is left open so the model continues from it.
        assistant_turn = im_start + encode("assistant\n" + query_assistant)
        if layout == "no_system":
            return "", user_turn + assistant_turn
        if layout == "chat":
            system_turn = im_start + encode("system\nYou are a helpful assistant.") + im_end + newline
            return "", system_turn + user_turn + assistant_turn
        raise ValueError(f"unknown prompt_layout: {layout!r}")
runner = ResearchRunner(model)

# First dialogue turn.
# output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
# print(decoded)
output_ids, history, decoded = runner.Chat(tokenizer, "你好!!", "")
print(decoded)

# One entry per returned id: its 1-based position padded to 3 digits, " : ",
# then the repr of the text that id decodes to on its own.
tokens = [
    f"{position:03d} : {tokenizer.decode([token_id])!r}"
    for position, token_id in enumerate(output_ids, start=1)
]
# Example of a full decoded exchange, for the Tokyo prompt commented out above:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# 东南亚国家日本的首都是什么市<|im_end|>
# <|im_start|>assistant
# 日本的首都东京。<|im_end|>
# <|endoftext|>
show.DumpListToFile(tokens, "./temp/token_decode_list.txt")

# Regression check for the Tokyo prompt (disabled):
# if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
#     raise AssertionError(decoded)

# Reference notes:
#   normalization: y = (x - mean) / (std + eps)  =>  sum(y) == 0
#   softmax:       y = exp(x) / sum(exp(x))      =>  0 < y < 1, sum(y) == 1