Refine model of wit.
parent f411b1cc5e
commit b3817f84fe
@@ -33,16 +33,6 @@ class QWenModel(nn.Module):
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.index = index
 
-    def _split_heads(self, tensor, num_heads, attn_head_size):
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(new_shape)
-        return tensor
-
-    def _merge_heads(self, tensor, num_heads, attn_head_size):
-        tensor = tensor.contiguous()
-        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-        return tensor.view(new_shape)
-
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
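The two helpers deleted above only reshape the last dimension between the fused hidden size and per-head slices; the same logic reappears inside QWenLMHeadModel below. A minimal standalone sketch of that round trip (shapes are made up for illustration, not taken from the config):

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# split: [batch, seq, num_heads * head_dim] -> [batch, seq, num_heads, head_dim]
split = hidden.view(hidden.size()[:-1] + (num_heads, head_dim))

# merge: the inverse view, collapsing the last two dimensions again
merged = split.contiguous().view(split.size()[:-2] + (num_heads * head_dim,))

assert torch.equal(hidden, merged)  # the round trip is lossless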
@@ -123,6 +113,11 @@ class QWenLMHeadModel(nn.Module):
         t_rot = (t_rot * cos) + (_rotate_half * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
+    def split_heads(self, tensor, num_heads, attn_head_size):
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        return tensor
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
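The context around the new split_heads shows rotary position embeddings being applied: the rotated slice is combined with the cos/sin tables and the untouched slice (t_pass) is concatenated back. In the diff, _rotate_half appears as an already-computed term; the sketch below uses the common rotate-half formulation computed inline, with hypothetical function names, purely to spell out the math:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rot_dim: int) -> torch.Tensor:
    # Only the first rot_dim channels are rotated; the rest pass through unchanged.
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1).type_as(t)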
@@ -149,9 +144,9 @@ class QWenLMHeadModel(nn.Module):
             atten = block.attn
             mixed_x_layer = atten.c_attn(layernorm_output)
             query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-            query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-            key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-            value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+            query = self.split_heads(query, atten.num_heads, atten.head_dim)
+            key = self.split_heads(key, atten.num_heads, atten.head_dim)
+            value = self.split_heads(value, atten.num_heads, atten.head_dim)
 
             # pos_emb
             rotary_pos_emb = rotary_pos_emb_list[0]
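This hunk switches the per-head reshape from the attention module's removed _split_heads to the new split_heads on the head model itself; the fused c_attn output is first split into query, key, and value along the last dimension. A small self-contained sketch of that split-then-reshape pattern, with made-up sizes and a plain nn.Linear standing in for c_attn:

import torch
import torch.nn as nn

hidden_size, num_heads = 32, 4
head_dim = hidden_size // num_heads
c_attn = nn.Linear(hidden_size, 3 * hidden_size)    # fused Q, K, V projection

x = torch.randn(2, 5, hidden_size)                  # [batch, seq, hidden]
mixed_x_layer = c_attn(x)                           # [batch, seq, 3 * hidden]
query, key, value = mixed_x_layer.split(hidden_size, dim=2)

def split_heads(tensor, num_heads, attn_head_size):
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    return tensor.view(new_shape)

query = split_heads(query, num_heads, head_dim)     # [batch, seq, heads, head_dim]
key = split_heads(key, num_heads, head_dim)
value = split_heads(value, num_heads, head_dim)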
@@ -173,9 +168,12 @@ class QWenLMHeadModel(nn.Module):
             attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
             if self.hook_attention:
                 self.hook_attention(query, key, causal_mask, index)
-            context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+            attn_output = attn_output.contiguous()
+            new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
+            context_layer = attn_output.view(new_shape)
             attn_outputs = block.attn.c_proj(context_layer)
 
             # mlp
             layernorm_input = attn_outputs + hidden_states
             layernorm_output = block.ln_2(layernorm_input)
             a1 = block.mlp.w1(layernorm_output)
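Here the _merge_heads call is inlined: F.scaled_dot_product_attention works on [batch, heads, seq, head_dim], the result is transposed back to [batch, seq, heads, head_dim], made contiguous, and the last two dimensions are flattened before the output projection. A self-contained sketch of that tail end of the attention block, with illustrative shapes and a plain nn.Linear standing in for block.attn.c_proj:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()  # True = may attend

# SDPA returns [batch, heads, seq, head_dim]; transpose back to [batch, seq, heads, head_dim].
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)

# Inlined merge: flatten the per-head dimensions back into one hidden dimension.
attn_output = attn_output.contiguous()
new_shape = attn_output.size()[:-2] + (num_heads * head_dim,)
context_layer = attn_output.view(new_shape)          # [batch, seq, heads * head_dim]

c_proj = nn.Linear(num_heads * head_dim, num_heads * head_dim)
attn_outputs = c_proj(context_layer)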