"""Research script: dump Qwen-1.8B-Chat attention maps during a chat turn.

Loads Qwen-1.8B-Chat and wraps it in a runner whose attention also saves an
image of each call's attention map under ``./temp/``. It then runs one chat
turn and checks the generated answer line.
"""

import math
import os
import sys

import torch
import torch.nn.functional as F
from modelscope import snapshot_download
from transformers import AutoConfig, AutoTokenizer

from modeling_qwen import QWenLMHeadModel, QwenRunner

# `tools` is imported from the parent directory.
sys.path.append("..")
from tools import show  # noqa: E402

seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# Offline alternative: point at an already-downloaded snapshot.
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# The config is read from the current directory ("./"), not from model_dir.
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
model = QWenLMHeadModel(config)
print(model)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# NOTE(review): if `from_pretrained` is the usual Hugging Face classmethod,
# calling it on the instance builds a *new* model configured from model_dir.
# The locally configured model above would then only be used for the print.
# Confirm this is intended.
model = model.from_pretrained(model_dir).cuda()
model = model.eval()


class ResearchRunner(QwenRunner):
    """QwenRunner whose attention also saves the attention map of each call."""

    # Directory the attention-map PNGs are written to (created on demand).
    dump_dir = "./temp/"

    def attention(self, attention, query, key, value, causal_mask):
        """Compute the attention output and save the attention-probability map.

        Args:
            attention: The attention module. This method reads ``index``,
                ``num_heads``, ``head_dim``, ``_merge_heads`` and ``c_proj``
                from it.
            query, key, value: Per-head projections. They are permuted with
                (0, 2, 1, 3) before use; presumably they arrive as
                (batch, seq, heads, head_dim). TODO confirm.
            causal_mask: Mask where truthy entries mean "may attend". This is
                the same tensor passed as ``attn_mask`` to
                ``F.scaled_dot_product_attention``.

        Returns:
            ``attention.c_proj`` applied to the merged attention heads.
        """
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        scale_factor = 1 / math.sqrt(query.size(-1))
        scores = query @ key.transpose(-2, -1) * scale_factor
        # Mask *before* the softmax so the dumped map equals what SDPA uses.
        # Masked positions get exactly zero probability and each row sums to 1.
        # Masking after the softmax would leak probability mass to positions
        # the model never attends to.
        scores = scores.masked_fill(causal_mask.logical_not(), float("-inf"))
        attn_prob = torch.softmax(scores, dim=-1)

        size = query.shape[2]  # query length of this call
        os.makedirs(self.dump_dir, exist_ok=True)
        file_name = "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
        # Only the first batch element is rendered.
        show.DumpTensorToImage(attn_prob[0], os.path.join(self.dump_dir, file_name))

        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=causal_mask
        ).transpose(1, 2)
        context_layer = attention._merge_heads(
            attn_output, attention.num_heads, attention.head_dim
        )
        return attention.c_proj(context_layer)


runner = ResearchRunner(model)

# First chat turn.
response, history, decode_tokens = runner.Chat(
    tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是"
)
print(decode_tokens)
# Example of the decoded transcript printed above:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# 东南亚国家日本的首都是什么市<|im_end|>
# <|im_start|>assistant
# 日本的首都东京。<|im_end|>
# <|endoftext|>
# Additional line recorded by the author (relation to the transcript is
# unclear from here — possibly an earlier run's answer):
# 日本的首都是东京。<|im_end|>

# The answer is expected on the second-to-last line of the transcript.
expected_answer = """日本的首都东京。<|im_end|>"""
transcript_lines = decode_tokens.split("\n")
answer_line = transcript_lines[-2] if len(transcript_lines) >= 2 else None
if answer_line != expected_answer:
    # `raise ()` would itself fail with TypeError (a tuple is not an exception).
    raise AssertionError(
        f"unexpected answer line {answer_line!r}, expected {expected_answer!r}"
    )