Refine qwen/research_attention.py.

This commit is contained in:
Colin 2024-01-21 17:54:05 +08:00
parent dab1c94bc6
commit ae6ea67bbe
1 changed file with 18 additions and 17 deletions

View File

@@ -1,5 +1,6 @@
import torch
import sys
import math
from modelscope import snapshot_download
from transformers import AutoTokenizer
from transformers import AutoConfig
@@ -7,9 +8,10 @@ from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
from modeling_qwen import QwenRunner
import torch.nn.functional as F
sys.path.append("..")
from tools import show
from tools import mem_tracker
seed = 4321
torch.manual_seed(seed)
@@ -31,29 +33,28 @@ print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir).cuda()
model = model.eval()
class ResearchRunner(QwenRunner):
def forwardAttention(
self,
attention,
hidden_states,
rotary_pos_emb_list=None,
):
query, key, value = self.split_heads(attention, hidden_states)
query, key = self.pos_emb(query, key, rotary_pos_emb_list)
causal_mask = self.build_mask(query)
def attention(self, attention, query, key, value, causal_mask):
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
q = query.permute(0, 2, 1, 3)
k = key.permute(0, 2, 1, 3)
size = query.shape[1]
qk = q @ k.transpose(-2, -1)
qk = qk[0]
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight = torch.softmax(attn_weight, dim=-1)
size = query.shape[2]
qk = attn_weight[0]
prePath = "./temp/"
show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
return self.attention(attention, query, key, value, causal_mask)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
attn_output = attention.c_proj(context_layer)
return attn_output
runner = ResearchRunner(model)