Add an attention hook for the query/key (qkv) projections.
commit 002f132818
parent 3eea09d78c

@@ -107,6 +107,21 @@ class QWenLMHeadModel(nn.Module):
         self.config = config
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.hook_attention = None
+
+    def apply_rotary_pos_emb(self, t, freqs):
+        rot_dim = freqs[0].shape[-1]
+        cos, sin = freqs
+        t_float = t.float()
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
+
+        x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
+        x1, x2 = x.unbind(dim=-2)
+        _rotate_half = torch.cat((-x2, x1), dim=-1)
+
+        t_rot = (t_rot * cos) + (_rotate_half * sin)
+        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
     def forward(
         self,
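Note (not part of the commit): the `apply_rotary_pos_emb` added above is the standard rotate-half form of rotary position embeddings (RoPE). The sketch below is a minimal standalone check of that method under the usual assumption that `cos`/`sin` are built from half-frequencies duplicated across the two halves of the head dimension; every name in it is illustrative, and it only verifies that the rotation leaves per-position, per-head vector norms unchanged.

import torch
from einops import rearrange

def apply_rotary_pos_emb(t, freqs):
    # standalone copy of the method added in the hunk above
    rot_dim = freqs[0].shape[-1]
    cos, sin = freqs
    t_float = t.float()
    t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
    x1, x2 = rearrange(t_rot, "... (j d) -> ... j d", j=2).unbind(dim=-2)
    rotate_half = torch.cat((-x2, x1), dim=-1)
    t_rot = (t_rot * cos) + (rotate_half * sin)
    return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

# assumed cos/sin construction: duplicated half-frequencies (rotate-half convention)
seq, heads, dim = 8, 2, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq).float(), inv_freq)   # [seq, dim/2]
emb = torch.cat((angles, angles), dim=-1)                   # [seq, dim]
cos = emb.cos()[None, :, None, :]                           # broadcast over batch and heads
sin = emb.sin()[None, :, None, :]

q = torch.randn(1, seq, heads, dim)                         # [batch, seq, heads, head_dim]
q_rot = apply_rotary_pos_emb(q, (cos, sin))
# a pure rotation preserves each head vector's norm
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))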
@@ -128,8 +143,47 @@ class QWenLMHeadModel(nn.Module):
         hidden_states = transfm.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        for block in transfm.h:
-            hidden_states = self.forwardBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
+        for index, block in enumerate(transfm.h):
+            layernorm_output = block.ln_1(hidden_states)
+            # split_heads
+            atten = block.attn
+            mixed_x_layer = atten.c_attn(layernorm_output)
+            query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
+            query = atten._split_heads(query, atten.num_heads, atten.head_dim)
+            key = atten._split_heads(key, atten.num_heads, atten.head_dim)
+            value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+
+            # pos_emb
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
+            key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
+
+            # build_mask
+            size = query.size(1)
+            causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(
+                1, 1, size, size
+            )
+
+            # attention
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
+            if self.hook_attention:
+                self.hook_attention(query, key, causal_mask, index)
+            context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+            attn_outputs = block.attn.c_proj(context_layer)
+
+            layernorm_input = attn_outputs + hidden_states
+            layernorm_output = block.ln_2(layernorm_input)
+            a1 = block.mlp.w1(layernorm_output)
+            a2 = block.mlp.w2(layernorm_output)
+            intermediate_parallel = a1 * F.silu(a2)
+            mlp_output = block.mlp.c_proj(intermediate_parallel)
+
+            hidden_states = layernorm_input + mlp_output
 
         hidden_states = transfm.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
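A note on the hook contract introduced above: when `hook_attention` is set, it is called once per transformer block with `(query, key, causal_mask, index)`. Judging from the `permute(0, 2, 1, 3)` applied just before `scaled_dot_product_attention`, `query` and `key` are still laid out as [batch, seq, heads, head_dim] at the call site, and `causal_mask` is [1, 1, seq, seq]. A minimal passive callback (illustrative names; `model` stands for a QWenLMHeadModel instance) could look like:

def log_attention_shapes(query, key, causal_mask, index):
    # read-only hook: report what each block hands us, change nothing
    print(f"block {index}: query {tuple(query.shape)}, key {tuple(key.shape)}, "
          f"mask {tuple(causal_mask.shape)}")

model.hook_attention = log_attention_shapes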
@@ -145,66 +199,6 @@ class QWenLMHeadModel(nn.Module):
             mask = shift_labels < self.config.vocab_size
             shift_labels = shift_labels[mask]
             shift_logits = shift_logits[mask]
-            # m = torch.max(shift_logits, 1).indices.cpu().numpy()
-            # ll = shift_labels.cpu().numpy()
             loss = CrossEntropyLoss()(shift_logits, shift_labels)
 
         return lm_logits, loss
-
-    def apply_rotary_pos_emb(self, t, freqs):
-        rot_dim = freqs[0].shape[-1]
-        cos, sin = freqs
-        t_float = t.float()
-        t_rot = t_float[..., :rot_dim]
-        t_pass = t_float[..., rot_dim:]
-
-        x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
-        x1, x2 = x.unbind(dim=-2)
-        _rotate_half = torch.cat((-x2, x1), dim=-1)
-
-        t_rot = (t_rot * cos) + (_rotate_half * sin)
-        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
-
-    def forwardBlock(
-        self,
-        block,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-    ):
-        layernorm_output = block.ln_1(hidden_states)
-
-        # split_heads
-        atten = block.attn
-        mixed_x_layer = atten.c_attn(layernorm_output)
-        query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-        query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-        key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-        value = atten._split_heads(value, atten.num_heads, atten.head_dim)
-        # pos_emb
-        rotary_pos_emb = rotary_pos_emb_list[0]
-        rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
-        rotary_pos_emb = (rotary_pos_emb,) * 2
-        query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
-        key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
-
-        # build_mask
-        size = query.size(1)
-        causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
-
-        # attention
-        q = query.permute(0, 2, 1, 3)
-        k = key.permute(0, 2, 1, 3)
-        v = value.permute(0, 2, 1, 3)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
-        context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
-        attn_outputs = block.attn.c_proj(context_layer)
-
-        layernorm_input = attn_outputs + hidden_states
-        layernorm_output = block.ln_2(layernorm_input)
-        a1 = block.mlp.w1(layernorm_output)
-        a2 = block.mlp.w2(layernorm_output)
-        intermediate_parallel = a1 * F.silu(a2)
-        mlp_output = block.mlp.c_proj(intermediate_parallel)
-
-        hidden_states = layernorm_input + mlp_output
-        return hidden_states
@@ -0,0 +1,51 @@
+import torch
+
+from model.qwen_module import QwenModule
+from model.qwen_module import ModelRunner
+import numpy as np
+
+import math
+import sys
+
+sys.path.append("..")
+from tools import show
+
+
+import dataset.dataset as ds
+
+if __name__ == "__main__":
+
+    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
+    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
+    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+
+    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen.eval()
+    conf = qwen.config
+    torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
+    runner = ModelRunner(qwen.llm)
+
+    def DumpQK(query, key, causal_mask, index):
+        size = query.shape[2]
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
+        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
+        attn_weight = attn_weight * attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = attn_weight * attn_mask
+        qk = attn_weight[0]
+        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
+        show.DumpTensorToImage(qk, prePath, GridValue=255)
+        # qk_seq.append(qk)
+        # qk_index = size
+
+    qwen.llm.hook_attention = DumpQK
+
+    batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
+    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+
+    print(sorted_logits.detach().cpu().numpy())
+    print(sorted_indices.detach().cpu().numpy())
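For reference, a hedged alternative to the `DumpQK` hook above that renders the same per-layer attention map with matplotlib instead of `tools.show`. It assumes the [batch, seq, heads, head_dim] call-site layout noted earlier and permutes before forming QK^T; the output path pattern is illustrative, and like the script above it expects `./temp/` to exist.

import math
import torch
import matplotlib
matplotlib.use("Agg")  # file output only, no display needed
import matplotlib.pyplot as plt

def dump_qk_matplotlib(query, key, causal_mask, index):
    # same signature as the commit's hook: (query, key, causal_mask, index)
    q = query.permute(0, 2, 1, 3).float()                     # -> [batch, heads, seq, head_dim]
    k = key.permute(0, 2, 1, 3).float()
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # [b, h, s, s]
    scores = scores.masked_fill(causal_mask.logical_not(), float("-inf"))
    weights = torch.softmax(scores, dim=-1)[0]                # first batch element: [h, s, s]
    seq = weights.shape[-1]
    grid = torch.cat(list(weights.detach().cpu()), dim=-1)    # heads tiled side by side: [s, h*s]
    plt.imsave(f"./temp/qk_seq_{seq}_layer_{index}.png", grid.numpy(), cmap="viridis")

Registering it the same way (`qwen.llm.hook_attention = dump_qk_matplotlib`) writes one PNG per block per forward pass.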