Refine model of qwen.

This commit is contained in:
Colin 2024-01-07 22:49:21 +08:00
parent 94ecf0f561
commit 69cb525ab0
1 changed files with 3 additions and 21 deletions

View File

@ -4,16 +4,13 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy import copy
import importlib
import math import math
import inspect import inspect
import pathlib
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
import warnings
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
@ -30,6 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging from transformers.utils import logging
from torch import nn from torch import nn
from einops import rearrange
from configuration_qwen import QWenConfig from configuration_qwen import QWenConfig
from qwen_generation_utils import ( from qwen_generation_utils import (
@ -63,11 +61,9 @@ class QWenAttention(nn.Module):
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias) self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
self.use_dynamic_ntk = config.use_dynamic_ntk self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)] logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
logn_tensor = torch.tensor(logn_list)[None, :, None, None] logn_tensor = torch.tensor(logn_list)[None, :, None, None]
@ -120,7 +116,7 @@ class QWenAttention(nn.Module):
present = (key, value) present = (key, value)
key_size = key.size(1) key_size = key.size(1)
if key_size > self.seq_length and self.use_logn_attn and not self.training: if key_size > self.seq_length and not self.training:
seq_start = key.size(1) - query.size(1) seq_start = key.size(1) - query.size(1)
seq_end = key.size(1) seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
@ -143,14 +139,11 @@ class QWenAttention(nn.Module):
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else: else:
attention_mask = causal_mask attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2) attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer) attn_output = self.c_proj(context_layer)
outputs = (attn_output, present) outputs = (attn_output, present)
return outputs return outputs
@ -184,7 +177,6 @@ class QWenBlock(nn.Module):
hidden_size, hidden_size,
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
self.mlp = QWenMLP(config) self.mlp = QWenMLP(config)
def forward( def forward(
@ -202,22 +194,16 @@ class QWenBlock(nn.Module):
layer_past=layer_past, layer_past=layer_past,
attention_mask=attention_mask, attention_mask=attention_mask,
) )
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
outputs = attn_outputs[1:] outputs = attn_outputs[1:]
residual = hidden_states residual = hidden_states
layernorm_input = attn_output + residual layernorm_input = attn_output + residual
layernorm_output = self.ln_2(layernorm_input) layernorm_output = self.ln_2(layernorm_input)
residual = layernorm_input residual = layernorm_input
mlp_output = self.mlp(layernorm_output) mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output hidden_states = residual + mlp_output
outputs = (hidden_states,) + outputs outputs = (hidden_states,) + outputs
return outputs return outputs
@ -732,8 +718,6 @@ class RotaryEmbedding(torch.nn.Module):
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
from einops import rearrange
emb = rearrange(emb, "n d -> 1 n 1 d") emb = rearrange(emb, "n d -> 1 n 1 d")
cos, sin = emb.cos(), emb.sin() cos, sin = emb.cos(), emb.sin()
@ -746,8 +730,6 @@ class RotaryEmbedding(torch.nn.Module):
def _rotate_half(x): def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2) x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2) x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)