diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index ab7741d..b251c5b 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -107,6 +107,21 @@ class QWenLMHeadModel(nn.Module):
         self.config = config
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.hook_attention = None
+
+    def apply_rotary_pos_emb(self, t, freqs):
+        rot_dim = freqs[0].shape[-1]
+        cos, sin = freqs
+        t_float = t.float()
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
+
+        x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
+        x1, x2 = x.unbind(dim=-2)
+        _rotate_half = torch.cat((-x2, x1), dim=-1)
+
+        t_rot = (t_rot * cos) + (_rotate_half * sin)
+        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
     def forward(
         self,
@@ -128,8 +143,47 @@ class QWenLMHeadModel(nn.Module):
         hidden_states = transfm.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        for block in transfm.h:
-            hidden_states = self.forwardBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
+        for index, block in enumerate(transfm.h):
+            layernorm_output = block.ln_1(hidden_states)
+            # split_heads
+            atten = block.attn
+            mixed_x_layer = atten.c_attn(layernorm_output)
+            query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
+            query = atten._split_heads(query, atten.num_heads, atten.head_dim)
+            key = atten._split_heads(key, atten.num_heads, atten.head_dim)
+            value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+
+            # pos_emb
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
+            key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
+
+            # build_mask
+            size = query.size(1)
+            causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(
+                1, 1, size, size
+            )
+
+            # attention
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
+            if self.hook_attention:
+                self.hook_attention(query, key, causal_mask, index)
+            context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+            attn_outputs = block.attn.c_proj(context_layer)
+
+            layernorm_input = attn_outputs + hidden_states
+            layernorm_output = block.ln_2(layernorm_input)
+            a1 = block.mlp.w1(layernorm_output)
+            a2 = block.mlp.w2(layernorm_output)
+            intermediate_parallel = a1 * F.silu(a2)
+            mlp_output = block.mlp.c_proj(intermediate_parallel)
+
+            hidden_states = layernorm_input + mlp_output
 
         hidden_states = transfm.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -145,66 +199,6 @@ class QWenLMHeadModel(nn.Module):
             mask = shift_labels < self.config.vocab_size
             shift_labels = shift_labels[mask]
             shift_logits = shift_logits[mask]
-            # m = torch.max(shift_logits, 1).indices.cpu().numpy()
-            # ll = shift_labels.cpu().numpy()
             loss = CrossEntropyLoss()(shift_logits, shift_labels)
         return lm_logits, loss
-
-    def apply_rotary_pos_emb(self, t, freqs):
-        rot_dim = freqs[0].shape[-1]
-        cos, sin = freqs
-        t_float = t.float()
-        t_rot = t_float[..., :rot_dim]
-        t_pass = t_float[..., rot_dim:]
-
-        x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
-        x1, x2 = x.unbind(dim=-2)
-        _rotate_half = torch.cat((-x2, x1), dim=-1)
-
-        t_rot = (t_rot * cos) + (_rotate_half * sin)
-        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
-
-    def forwardBlock(
-        self,
-        block,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-    ):
-        layernorm_output = block.ln_1(hidden_states)
-
-        # split_heads
-        atten = block.attn
-        mixed_x_layer = atten.c_attn(layernorm_output)
-        query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-        query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-        key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-        value = atten._split_heads(value, atten.num_heads, atten.head_dim)
-        # pos_emb
-        rotary_pos_emb = rotary_pos_emb_list[0]
-        rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
-        rotary_pos_emb = (rotary_pos_emb,) * 2
-        query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
-        key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
-
-        # build_mask
-        size = query.size(1)
-        causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
-
-        # attention
-        q = query.permute(0, 2, 1, 3)
-        k = key.permute(0, 2, 1, 3)
-        v = value.permute(0, 2, 1, 3)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
-        context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
-        attn_outputs = block.attn.c_proj(context_layer)
-
-        layernorm_input = attn_outputs + hidden_states
-        layernorm_output = block.ln_2(layernorm_input)
-        a1 = block.mlp.w1(layernorm_output)
-        a2 = block.mlp.w2(layernorm_output)
-        intermediate_parallel = a1 * F.silu(a2)
-        mlp_output = block.mlp.c_proj(intermediate_parallel)
-
-        hidden_states = layernorm_input + mlp_output
-        return hidden_states
diff --git a/wit/query_block_output.py b/wit/query_block_output.py
new file mode 100644
index 0000000..1b04e45
--- /dev/null
+++ b/wit/query_block_output.py
@@ -0,0 +1,51 @@
+import torch
+
+from model.qwen_module import QwenModule
+from model.qwen_module import ModelRunner
+import numpy as np
+
+import math
+import sys
+
+sys.path.append("..")
+from tools import show
+
+
+import dataset.dataset as ds
+
+if __name__ == "__main__":
+
+    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
+    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
+    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+
+    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen.eval()
+    conf = qwen.config
+    torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
+    runner = ModelRunner(qwen.llm)
+
+    def DumpQK(query, key, causal_mask, index):
+        size = query.shape[2]
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
+        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
+        attn_weight = attn_weight * attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = attn_weight * attn_mask
+        qk = attn_weight[0]
+        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
+        show.DumpTensorToImage(qk, prePath, GridValue=255)
+        # qk_seq.append(qk)
+        # qk_index = size
+
+    qwen.llm.hook_attention = DumpQK
+
+    batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
+    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+
+    print(sorted_logits.detach().cpu().numpy())
+    print(sorted_indices.detach().cpu().numpy())
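
Note (reviewer sketch, not part of the patch): the new `hook_attention` attribute on QWenLMHeadModel is called once per transformer block with the signature `hook_attention(query, key, causal_mask, index)`, as seen in the forward loop above, and `query_block_output.py` registers `DumpQK` as one such hook to dump per-layer q@k maps. A minimal usage sketch of the same mechanism is shown below; `model`, `record_qk`, and `captured` are illustrative names, not identifiers from this repository, and only the call signature visible in the diff is assumed.

    # Hedged sketch: capture per-layer query/key tensors via the new hook.
    captured = []

    def record_qk(query, key, causal_mask, index):
        # Detach so the hook does not keep the autograd graph alive.
        captured.append((index, query.detach().cpu(), key.detach().cpu()))

    model.hook_attention = record_qk
    # ... run the model's forward pass as usual; the hook fires for every block ...
    model.hook_attention = None  # detach the hook when finished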