Refine model of wit.
parent f411b1cc5e
commit b3817f84fe
@@ -33,16 +33,6 @@ class QWenModel(nn.Module):
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.index = index
 
-    def _split_heads(self, tensor, num_heads, attn_head_size):
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(new_shape)
-        return tensor
-
-    def _merge_heads(self, tensor, num_heads, attn_head_size):
-        tensor = tensor.contiguous()
-        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-        return tensor.view(new_shape)
-
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
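For context, the two helpers removed above are plain reshapes that are inverses of each other. A minimal sketch of the shape round-trip, with illustrative sizes (the real values come from the model config):

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8   # illustrative sizes only

x = torch.randn(batch, seq_len, num_heads * head_dim)

# _split_heads: [batch, seq, hidden] -> [batch, seq, heads, head_dim]
split = x.view(x.size()[:-1] + (num_heads, head_dim))

# _merge_heads: [batch, seq, heads, head_dim] -> [batch, seq, hidden]
merged = split.contiguous().view(split.size()[:-2] + (num_heads * head_dim,))

assert torch.equal(x, merged)   # split followed by merge is the identity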
@@ -123,6 +113,11 @@ class QWenLMHeadModel(nn.Module):
         t_rot = (t_rot * cos) + (_rotate_half * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
+    def split_heads(self, tensor, num_heads, attn_head_size):
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        return tensor
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
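The t_rot line in the context above applies the usual rotary-embedding identity. A sketch of the standard rotate-half formulation (an assumption here; the repository's own _rotate_half may differ in layout), with made-up sizes:

import torch

def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, head_dim = 5, 8                      # illustrative sizes only
t_rot = torch.randn(1, seq_len, 1, head_dim)
freqs = torch.randn(1, seq_len, 1, head_dim)  # stand-in for rotary_pos_emb_list[0]
cos, sin = freqs.cos(), freqs.sin()

t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)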
@@ -149,9 +144,9 @@ class QWenLMHeadModel(nn.Module):
             atten = block.attn
             mixed_x_layer = atten.c_attn(layernorm_output)
             query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-            query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-            key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-            value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+            query = self.split_heads(query, atten.num_heads, atten.head_dim)
+            key = self.split_heads(key, atten.num_heads, atten.head_dim)
+            value = self.split_heads(value, atten.num_heads, atten.head_dim)
 
             # pos_emb
             rotary_pos_emb = rotary_pos_emb_list[0]
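To illustrate the call-site change above: the fused c_attn projection is split into query/key/value along the last dimension, and each piece is then reshaped per head by the new split_heads. A rough, self-contained sketch with assumed sizes (rotary embedding and mask handling omitted):

import torch
import torch.nn as nn

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8   # illustrative sizes only
hidden = num_heads * head_dim
split_size = hidden                                 # stands in for atten.split_size

c_attn = nn.Linear(hidden, 3 * hidden)              # fused q/k/v projection
layernorm_output = torch.randn(batch, seq_len, hidden)

mixed_x_layer = c_attn(layernorm_output)            # [batch, seq, 3 * hidden]
query, key, value = mixed_x_layer.split(split_size, dim=2)

def split_heads(tensor, num_heads, attn_head_size):
    # Same reshape as the method added in this commit.
    return tensor.view(tensor.size()[:-1] + (num_heads, attn_head_size))

query = split_heads(query, num_heads, head_dim)     # [batch, seq, heads, head_dim]
key = split_heads(key, num_heads, head_dim)
value = split_heads(value, num_heads, head_dim)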
@@ -173,9 +168,12 @@ class QWenLMHeadModel(nn.Module):
             attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
             if self.hook_attention:
                 self.hook_attention(query, key, causal_mask, index)
-            context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+            attn_output = attn_output.contiguous()
+            new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
+            context_layer = attn_output.view(new_shape)
             attn_outputs = block.attn.c_proj(context_layer)
 
+            # mlp
             layernorm_input = attn_outputs + hidden_states
             layernorm_output = block.ln_2(layernorm_input)
             a1 = block.mlp.w1(layernorm_output)
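The inlined contiguous()/view() pair above performs the same head merge that block.attn._merge_heads used to do. A minimal sketch with assumed sizes, using is_causal=True in place of the explicit causal_mask:

import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 4, 5, 8    # illustrative sizes only
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim]
attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2)

# Inlined merge: collapse (heads, head_dim) back into one hidden dimension.
attn_output = attn_output.contiguous()
new_shape = attn_output.size()[:-2] + (num_heads * head_dim,)
context_layer = attn_output.view(new_shape)

print(context_layer.shape)   # torch.Size([2, 5, 32])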