2024-01-21 16:46:00 +08:00
|
|
|
import torch
|
|
|
|
import sys
|
2024-01-21 17:54:05 +08:00
|
|
|
import math
|
2024-01-21 16:46:00 +08:00
|
|
|
from modelscope import snapshot_download
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
from transformers import AutoConfig
|
|
|
|
|
|
|
|
from modeling_qwen import QWenLMHeadModel
|
|
|
|
from modeling_qwen import QwenRunner
|
|
|
|
|
2024-01-21 17:54:05 +08:00
|
|
|
import torch.nn.functional as F
|
|
|
|
|
2024-01-21 16:46:00 +08:00
|
|
|
sys.path.append("..")
|
|
|
|
from tools import show
|
|
|
|
|
|
|
|
# Seed the default torch generator and every CUDA generator with the same
# value so that repeated runs of this script start from the same RNG state.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Checkpoint location as returned by modelscope's snapshot_download; it is
# passed to both the tokenizer and the model loader below.
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# Alternative kept from the original script: point at an existing local copy.
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# The configuration is read from the current directory ("./"), not from
# model_dir. With return_unused_kwargs=True the call yields a
# (config, unused_kwargs) pair.
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

# Build a model from the configuration alone and print its module layout.
# This instance is replaced by the pretrained one a few lines further down.
model = QWenLMHeadModel(config)
print(model)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

# Load the checkpoint from model_dir, move it to the GPU, and switch it to
# evaluation mode.
model = model.from_pretrained(model_dir)
model = model.cuda().eval()
|
|
|
|
|
|
|
|
|
|
|
|
class ResearchRunner(QwenRunner):
    """QwenRunner variant that dumps an attention map image per attention call."""

    def attention(self, attention, query, key, value, causal_mask):
        """Compute the attention output and save a q@k visualisation on the way.

        Args:
            attention: Attention module; this method reads its ``index``,
                ``num_heads``, ``head_dim``, ``_merge_heads`` and ``c_proj``.
            query: Query tensor; dims 1 and 2 are swapped before use.
                Presumably (batch, seq, heads, head_dim) on entry -- TODO confirm.
            key: Key tensor, laid out like ``query``.
            value: Value tensor, laid out like ``query``.
            causal_mask: Mask handed to ``F.scaled_dot_product_attention``; its
                logical negation selects the entries zeroed in the dumped map.

        Returns:
            The result of ``attention.c_proj`` applied to the merged heads.
        """
        # Swap dims 1 and 2 of each input before any matrix product.
        q = query.permute(0, 2, 1, 3)
        k = key.permute(0, 2, 1, 3)
        v = value.permute(0, 2, 1, 3)

        # --- Visualisation only: none of this feeds into the returned value. ---
        scale = 1 / math.sqrt(q.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        probs = scores.softmax(dim=-1)
        # NOTE(review): the softmax runs over the *unmasked* scores and the mask
        # is applied afterwards, so the dumped rows need not sum to 1 and may
        # differ from what scaled_dot_product_attention computes internally.
        # Confirm this is the intended picture.
        seq_len = q.shape[2]

        # 1 where causal_mask is truthy, 0 elsewhere, in the dtype/device of q.
        keep = torch.ones(causal_mask.shape, dtype=q.dtype, device=q.device).masked_fill_(
            causal_mask.logical_not(), float(0)
        )
        # Only the first entry along dim 0 is written to disk.
        first_map = (probs * keep)[0]

        image_path = f"./temp/q@k_seq_{seq_len!s}_layer_{attention.index!s}.png"
        show.DumpTensorToImage(first_map, image_path)

        # --- Actual attention output. ---
        heads_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
        merged = attention._merge_heads(heads_out, attention.num_heads, attention.head_dim)
        return attention.c_proj(merged)
|
2024-01-21 16:46:00 +08:00
|
|
|
|
|
|
|
|
|
|
|
runner = ResearchRunner(model)

# First round of conversation.
# The third argument looks like a prefix the assistant's reply starts with --
# TODO confirm against QwenRunner.Chat.
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
print(decode_tokens)

# Sample transcript recorded with the original version of this script
# (model output kept verbatim):
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# 东南亚国家日本的首都是什么市<|im_end|>
# <|im_start|>assistant
# 日本的首都东京。<|im_end|>
# <|endoftext|>
#
# Assistant line recorded after the reply prefix was added to the Chat call:
# 日本的首都是东京。<|im_end|>

# Regression check: the second-to-last line of the decoded text must be the
# expected assistant answer.
# NOTE(review): the literal below matches the older transcript; the newer
# sample line above (and the "日本的首都是" prefix passed to Chat) contain an
# extra "是". Confirm which answer is expected before relying on this check.
expected_last_line = """日本的首都东京。<|im_end|>"""
actual_last_line = decode_tokens.split("\n")[-2]
if actual_last_line != expected_last_line:
    # The original `raise ()` raised an empty tuple, which Python rejects with
    # "TypeError: exceptions must derive from BaseException" and no context.
    # Raise a real exception that reports what was compared instead.
    raise RuntimeError(
        f"unexpected model output: expected {expected_last_line!r}, got {actual_last_line!r}"
    )
|