Add hook of attention for query qkv.
This commit is contained in:
parent
3eea09d78c
commit
002f132818
|
@ -107,6 +107,21 @@ class QWenLMHeadModel(nn.Module):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.transformer = QWenModel(config)
|
self.transformer = QWenModel(config)
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
self.hook_attention = None
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(self, t, freqs):
|
||||||
|
rot_dim = freqs[0].shape[-1]
|
||||||
|
cos, sin = freqs
|
||||||
|
t_float = t.float()
|
||||||
|
t_rot = t_float[..., :rot_dim]
|
||||||
|
t_pass = t_float[..., rot_dim:]
|
||||||
|
|
||||||
|
x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
|
||||||
|
x1, x2 = x.unbind(dim=-2)
|
||||||
|
_rotate_half = torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
t_rot = (t_rot * cos) + (_rotate_half * sin)
|
||||||
|
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -128,8 +143,47 @@ class QWenLMHeadModel(nn.Module):
|
||||||
hidden_states = transfm.drop(hidden_states)
|
hidden_states = transfm.drop(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
for block in transfm.h:
|
for index, block in enumerate(transfm.h):
|
||||||
hidden_states = self.forwardBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
|
layernorm_output = block.ln_1(hidden_states)
|
||||||
|
# split_heads
|
||||||
|
atten = block.attn
|
||||||
|
mixed_x_layer = atten.c_attn(layernorm_output)
|
||||||
|
query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
|
||||||
|
query = atten._split_heads(query, atten.num_heads, atten.head_dim)
|
||||||
|
key = atten._split_heads(key, atten.num_heads, atten.head_dim)
|
||||||
|
value = atten._split_heads(value, atten.num_heads, atten.head_dim)
|
||||||
|
|
||||||
|
# pos_emb
|
||||||
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
|
rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
|
||||||
|
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||||
|
query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
|
||||||
|
key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
|
||||||
|
|
||||||
|
# build_mask
|
||||||
|
size = query.size(1)
|
||||||
|
causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(
|
||||||
|
1, 1, size, size
|
||||||
|
)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
q = query.permute(0, 2, 1, 3)
|
||||||
|
k = key.permute(0, 2, 1, 3)
|
||||||
|
v = value.permute(0, 2, 1, 3)
|
||||||
|
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
|
||||||
|
if self.hook_attention:
|
||||||
|
self.hook_attention(query, key, causal_mask, index)
|
||||||
|
context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
|
||||||
|
attn_outputs = block.attn.c_proj(context_layer)
|
||||||
|
|
||||||
|
layernorm_input = attn_outputs + hidden_states
|
||||||
|
layernorm_output = block.ln_2(layernorm_input)
|
||||||
|
a1 = block.mlp.w1(layernorm_output)
|
||||||
|
a2 = block.mlp.w2(layernorm_output)
|
||||||
|
intermediate_parallel = a1 * F.silu(a2)
|
||||||
|
mlp_output = block.mlp.c_proj(intermediate_parallel)
|
||||||
|
|
||||||
|
hidden_states = layernorm_input + mlp_output
|
||||||
|
|
||||||
hidden_states = transfm.ln_f(hidden_states)
|
hidden_states = transfm.ln_f(hidden_states)
|
||||||
hidden_states = hidden_states.view(output_shape)
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
@ -145,66 +199,6 @@ class QWenLMHeadModel(nn.Module):
|
||||||
mask = shift_labels < self.config.vocab_size
|
mask = shift_labels < self.config.vocab_size
|
||||||
shift_labels = shift_labels[mask]
|
shift_labels = shift_labels[mask]
|
||||||
shift_logits = shift_logits[mask]
|
shift_logits = shift_logits[mask]
|
||||||
# m = torch.max(shift_logits, 1).indices.cpu().numpy()
|
|
||||||
# ll = shift_labels.cpu().numpy()
|
|
||||||
loss = CrossEntropyLoss()(shift_logits, shift_labels)
|
loss = CrossEntropyLoss()(shift_logits, shift_labels)
|
||||||
|
|
||||||
return lm_logits, loss
|
return lm_logits, loss
|
||||||
|
|
||||||
def apply_rotary_pos_emb(self, t, freqs):
|
|
||||||
rot_dim = freqs[0].shape[-1]
|
|
||||||
cos, sin = freqs
|
|
||||||
t_float = t.float()
|
|
||||||
t_rot = t_float[..., :rot_dim]
|
|
||||||
t_pass = t_float[..., rot_dim:]
|
|
||||||
|
|
||||||
x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
|
|
||||||
x1, x2 = x.unbind(dim=-2)
|
|
||||||
_rotate_half = torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
t_rot = (t_rot * cos) + (_rotate_half * sin)
|
|
||||||
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
|
|
||||||
|
|
||||||
def forwardBlock(
|
|
||||||
self,
|
|
||||||
block,
|
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
|
||||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
|
||||||
):
|
|
||||||
layernorm_output = block.ln_1(hidden_states)
|
|
||||||
|
|
||||||
# split_heads
|
|
||||||
atten = block.attn
|
|
||||||
mixed_x_layer = atten.c_attn(layernorm_output)
|
|
||||||
query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
|
|
||||||
query = atten._split_heads(query, atten.num_heads, atten.head_dim)
|
|
||||||
key = atten._split_heads(key, atten.num_heads, atten.head_dim)
|
|
||||||
value = atten._split_heads(value, atten.num_heads, atten.head_dim)
|
|
||||||
# pos_emb
|
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
|
||||||
rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
|
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
|
||||||
query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
|
|
||||||
key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
|
|
||||||
|
|
||||||
# build_mask
|
|
||||||
size = query.size(1)
|
|
||||||
causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
|
|
||||||
|
|
||||||
# attention
|
|
||||||
q = query.permute(0, 2, 1, 3)
|
|
||||||
k = key.permute(0, 2, 1, 3)
|
|
||||||
v = value.permute(0, 2, 1, 3)
|
|
||||||
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
|
|
||||||
context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
|
|
||||||
attn_outputs = block.attn.c_proj(context_layer)
|
|
||||||
|
|
||||||
layernorm_input = attn_outputs + hidden_states
|
|
||||||
layernorm_output = block.ln_2(layernorm_input)
|
|
||||||
a1 = block.mlp.w1(layernorm_output)
|
|
||||||
a2 = block.mlp.w2(layernorm_output)
|
|
||||||
intermediate_parallel = a1 * F.silu(a2)
|
|
||||||
mlp_output = block.mlp.c_proj(intermediate_parallel)
|
|
||||||
|
|
||||||
hidden_states = layernorm_input + mlp_output
|
|
||||||
return hidden_states
|
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from model.qwen_module import QwenModule
|
||||||
|
from model.qwen_module import ModelRunner
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append("..")
|
||||||
|
from tools import show
|
||||||
|
|
||||||
|
|
||||||
|
import dataset.dataset as ds
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
|
||||||
|
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
|
||||||
|
checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
|
||||||
|
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
|
||||||
|
|
||||||
|
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||||
|
qwen.eval()
|
||||||
|
conf = qwen.config
|
||||||
|
torch.manual_seed(conf.seed)
|
||||||
|
np.random.seed(conf.seed)
|
||||||
|
runner = ModelRunner(qwen.llm)
|
||||||
|
|
||||||
|
def DumpQK(query, key, causal_mask, index):
|
||||||
|
size = query.shape[2]
|
||||||
|
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||||
|
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||||
|
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
|
||||||
|
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
||||||
|
attn_weight = attn_weight * attn_mask
|
||||||
|
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||||
|
attn_weight = attn_weight * attn_mask
|
||||||
|
qk = attn_weight[0]
|
||||||
|
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||||
|
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||||
|
# qk_seq.append(qk)
|
||||||
|
# qk_index = size
|
||||||
|
|
||||||
|
qwen.llm.hook_attention = DumpQK
|
||||||
|
|
||||||
|
batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
|
||||||
|
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||||
|
|
||||||
|
print(sorted_logits.detach().cpu().numpy())
|
||||||
|
print(sorted_indices.detach().cpu().numpy())
|
Loading…
Reference in New Issue