Refine qwen/research_attention.py.
parent dab1c94bc6
commit ae6ea67bbe
--- a/qwen/research_attention.py
+++ b/qwen/research_attention.py
@@ -1,5 +1,6 @@
 import torch
 import sys
+import math
 from modelscope import snapshot_download
 from transformers import AutoTokenizer
 from transformers import AutoConfig
@@ -7,9 +8,10 @@ from transformers import AutoConfig
 from modeling_qwen import QWenLMHeadModel
 from modeling_qwen import QwenRunner
 
+import torch.nn.functional as F
+
 sys.path.append("..")
 from tools import show
-from tools import mem_tracker
 
 seed = 4321
 torch.manual_seed(seed)
@@ -31,29 +33,28 @@ print(model)
 
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 model = model.from_pretrained(model_dir).cuda()
 
 model = model.eval()
 
 
 class ResearchRunner(QwenRunner):
-    def forwardAttention(
-        self,
-        attention,
-        hidden_states,
-        rotary_pos_emb_list=None,
-    ):
-        query, key, value = self.split_heads(attention, hidden_states)
-        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
-        causal_mask = self.build_mask(query)
+    def attention(self, attention, query, key, value, causal_mask):
+        query = query.permute(0, 2, 1, 3)
+        key = key.permute(0, 2, 1, 3)
+        value = value.permute(0, 2, 1, 3)
 
-        q = query.permute(0, 2, 1, 3)
-        k = key.permute(0, 2, 1, 3)
-        size = query.shape[1]
-        qk = q @ k.transpose(-2, -1)
-        qk = qk[0]
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+
+        size = query.shape[2]
+        qk = attn_weight[0]
         prePath = "./temp/"
         show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
-        return self.attention(attention, query, key, value, causal_mask)
+        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
+        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
+        attn_output = attention.c_proj(context_layer)
+        return attn_output
 
 
 runner = ResearchRunner(model)
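Note on the change: the overridden attention() now recomputes the attention map explicitly as softmax(Q K^T / sqrt(d)) so that the first batch element can be dumped to an image per layer, while the returned output still goes through F.scaled_dot_product_attention. The following is a minimal standalone sketch, not part of the commit, checking that the manually computed weights reproduce the SDPA result; the tensor shapes and the unmasked, dropout-free setting are illustrative assumptions.

# Sketch only (not from the commit): softmax(Q K^T / sqrt(d)) @ V versus
# F.scaled_dot_product_attention with no mask and no dropout.
import math
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq, head_dim).
batch, heads, seq, head_dim = 1, 2, 8, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Same computation the new attention() override uses for visualization.
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
manual_out = attn_weight @ value

# The override's actual output path.
sdpa_out = F.scaled_dot_product_attention(query, key, value)

print(torch.allclose(manual_out, sdpa_out, atol=1e-6))  # expected: True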