Refine qwen/research_attention.py.
parent dab1c94bc6
commit ae6ea67bbe

@@ -1,5 +1,6 @@
import torch
import sys
import math
from modelscope import snapshot_download
from transformers import AutoTokenizer
from transformers import AutoConfig

@@ -7,9 +8,10 @@ from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
from modeling_qwen import QwenRunner

import torch.nn.functional as F

sys.path.append("..")
from tools import show
from tools import mem_tracker

seed = 4321
torch.manual_seed(seed)

@@ -31,29 +33,28 @@ print(model)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir).cuda()

model = model.eval()


class ResearchRunner(QwenRunner):
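    # Subclass of QwenRunner; it presumably overrides the runner's attention
    # hooks so the per-layer attention maps can be dumped as images via tools.show.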
    def forwardAttention(
        self,
        attention,
        hidden_states,
        rotary_pos_emb_list=None,
    ):
        query, key, value = self.split_heads(attention, hidden_states)
        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
        causal_mask = self.build_mask(query)

    def attention(self, attention, query, key, value, causal_mask):
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        q = query.permute(0, 2, 1, 3)
        k = key.permute(0, 2, 1, 3)
        size = query.shape[1]
        qk = q @ k.transpose(-2, -1)
        qk = qk[0]
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight = torch.softmax(attn_weight, dim=-1)

        size = query.shape[2]
        qk = attn_weight[0]
        prePath = "./temp/"
        show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
        return self.attention(attention, query, key, value, causal_mask)

        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
        attn_output = attention.c_proj(context_layer)
        return attn_output


runner = ResearchRunner(model)
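For reference, a minimal standalone sketch of the attention-map dump that the reworked attention() performs: scaled dot-product scores are softmax-normalized, and the first batch element's per-head maps are written out as images. The tensor sizes below are made up, and matplotlib's imsave stands in for the repo's show.DumpTensorToImage helper, so treat this as an illustration rather than the project's actual code.

import math
import os

import matplotlib.pyplot as plt
import torch

os.makedirs("./temp", exist_ok=True)

# Toy shapes: (batch, num_heads, seq_len, head_dim), matching the layout after permute(0, 2, 1, 3).
query = torch.randn(1, 2, 8, 16)
key = torch.randn(1, 2, 8, 16)

# Same computation as in attention(): softmax(Q @ K^T / sqrt(head_dim)).
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)

# attn_weight[0] has shape (num_heads, seq_len, seq_len); save one image per head.
for head, attn_map in enumerate(attn_weight[0]):
    plt.imsave(f"./temp/q@k_seq_{query.shape[2]}_head_{head}.png", attn_map.numpy(), cmap="viridis")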