diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index b251c5b..6d90e44 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -33,16 +33,6 @@ class QWenModel(nn.Module):
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.index = index
 
-    def _split_heads(self, tensor, num_heads, attn_head_size):
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(new_shape)
-        return tensor
-
-    def _merge_heads(self, tensor, num_heads, attn_head_size):
-        tensor = tensor.contiguous()
-        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-        return tensor.view(new_shape)
-
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -123,6 +113,11 @@ class QWenLMHeadModel(nn.Module):
         t_rot = (t_rot * cos) + (_rotate_half * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
+    def split_heads(self, tensor, num_heads, attn_head_size):
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        return tensor
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -149,9 +144,9 @@ class QWenLMHeadModel(nn.Module):
             atten = block.attn
             mixed_x_layer = atten.c_attn(layernorm_output)
             query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-            query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-            key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-            value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+            query = self.split_heads(query, atten.num_heads, atten.head_dim)
+            key = self.split_heads(key, atten.num_heads, atten.head_dim)
+            value = self.split_heads(value, atten.num_heads, atten.head_dim)
 
             # pos_emb
             rotary_pos_emb = rotary_pos_emb_list[0]
@@ -173,9 +168,12 @@ class QWenLMHeadModel(nn.Module):
             attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
             if self.hook_attention:
                 self.hook_attention(query, key, causal_mask, index)
-            context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+            attn_output = attn_output.contiguous()
+            new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
+            context_layer = attn_output.view(new_shape)
             attn_outputs = block.attn.c_proj(context_layer)
 
+            # mlp
             layernorm_input = attn_outputs + hidden_states
             layernorm_output = block.ln_2(layernorm_input)
             a1 = block.mlp.w1(layernorm_output)
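For reference, and not part of the patch: a minimal standalone sketch of the tensor reshapes that the relocated split_heads helper and the inlined merge perform, using assumed example dimensions (batch=2, seq=5, num_heads=4, head_dim=8).

import torch

batch, seq, num_heads, head_dim = 2, 5, 4, 8  # assumed example sizes
hidden = torch.randn(batch, seq, num_heads * head_dim)

# split_heads: [batch, seq, num_heads * head_dim] -> [batch, seq, num_heads, head_dim]
split = hidden.view(hidden.size()[:-1] + (num_heads, head_dim))

# inlined merge (formerly _merge_heads): back to [batch, seq, num_heads * head_dim]
merged = split.contiguous().view(split.size()[:-2] + (num_heads * head_dim,))

assert split.shape == (batch, seq, num_heads, head_dim)
assert torch.equal(merged, hidden)  # round trip is lossless

In the model itself, attn_output arrives from a transpose(1, 2) of the scaled_dot_product_attention result, which is why the .contiguous() call is kept before the final view in the inlined merge.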