Refine the Qwen model.
parent 94ecf0f561
commit 69cb525ab0
@@ -4,16 +4,13 @@
 # LICENSE file in the root directory of this source tree.

 import copy
-import importlib
 import math
 import inspect
-import pathlib
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator

 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-import warnings

 from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
@@ -30,6 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging

 from torch import nn
+from einops import rearrange

 from configuration_qwen import QWenConfig
 from qwen_generation_utils import (
@@ -63,11 +61,9 @@ class QWenAttention(nn.Module):
         self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads

         self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)

         self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)

         self.use_dynamic_ntk = config.use_dynamic_ntk
-        self.use_logn_attn = config.use_logn_attn
-
         logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
         logn_tensor = torch.tensor(logn_list)[None, :, None, None]
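For intuition, a minimal standalone sketch of what the logn table kept above contains; the seq_length value of 2048 is an assumed example, not taken from this repository's config:

# Illustrative sketch only: values of the log-n scaling table built in __init__.
# seq_length = 2048 is an assumed example value.
import math
import torch

seq_length = 2048
logn_list = [math.log(i, seq_length) if i > seq_length else 1 for i in range(1, 32768)]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]  # shape (1, 32767, 1, 1)

print(logn_tensor[0, 2047, 0, 0].item())  # 1.0   -> position 2048, at the training length
print(logn_tensor[0, 4095, 0, 0].item())  # ~1.09 -> position 4096, scaled by log_2048(4096)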
@@ -120,7 +116,7 @@ class QWenAttention(nn.Module):
         present = (key, value)

         key_size = key.size(1)
-        if key_size > self.seq_length and self.use_logn_attn and not self.training:
+        if key_size > self.seq_length and not self.training:
             seq_start = key.size(1) - query.size(1)
             seq_end = key.size(1)
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
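A rough sketch of how that slice of the table is then used; the final multiplication and the (batch, seq, heads, head_dim) layout are assumptions for illustration, not lines taken from this diff:

# Hypothetical helper mirroring the branch above; shapes are assumed to be
# query/key: (batch, seq, heads, head_dim), logn_tensor: (1, max_pos, 1, 1).
import torch

def scale_long_queries(query, key, logn_tensor, seq_length, training=False):
    key_size = key.size(1)
    if key_size > seq_length and not training:
        seq_start = key.size(1) - query.size(1)
        seq_end = key.size(1)
        scale = logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
        query = query * scale.expand_as(query)  # assumed application step
    return query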
@@ -143,14 +139,11 @@ class QWenAttention(nn.Module):
             attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
         else:
             attention_mask = causal_mask

         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)

         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

         attn_output = self.c_proj(context_layer)

         outputs = (attn_output, present)

         return outputs

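For reference, a self-contained sketch of the F.scaled_dot_product_attention call pattern used above; the tensor sizes and the boolean mask are made-up example values:

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 32, 8, 16, 128
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Boolean mask broadcastable to (batch, heads, q_len, kv_len); True = attend.
attention_mask = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
print(attn_output.shape)  # torch.Size([2, 8, 32, 128]) -> (batch, seq, heads, head_dim)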
@@ -184,7 +177,6 @@ class QWenBlock(nn.Module):
             hidden_size,
             eps=config.layer_norm_epsilon,
         )

         self.mlp = QWenMLP(config)

     def forward(
@@ -202,22 +194,16 @@ class QWenBlock(nn.Module):
             layer_past=layer_past,
             attention_mask=attention_mask,
         )

         attn_output = attn_outputs[0]

         outputs = attn_outputs[1:]

         residual = hidden_states
         layernorm_input = attn_output + residual

         layernorm_output = self.ln_2(layernorm_input)

         residual = layernorm_input
         mlp_output = self.mlp(layernorm_output)
         hidden_states = residual + mlp_output

         outputs = (hidden_states,) + outputs

         return outputs
@@ -732,8 +718,6 @@ class RotaryEmbedding(torch.nn.Module):
         freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)

         emb = torch.cat((freqs, freqs), dim=-1)
-        from einops import rearrange
-
         emb = rearrange(emb, "n d -> 1 n 1 d")

         cos, sin = emb.cos(), emb.sin()
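With the local import gone, rearrange here resolves to the module-level einops import added earlier. As a sanity check, a small sketch of the equivalent reshaping (sizes are made up):

import torch
from einops import rearrange

freqs = torch.outer(torch.arange(16).float(), torch.ones(4))
emb = torch.cat((freqs, freqs), dim=-1)   # (n, d) = (16, 8)

a = rearrange(emb, "n d -> 1 n 1 d")      # (1, 16, 1, 8)
b = emb[None, :, None, :]                 # same result via indexing
print(torch.equal(a, b))                  # True

cos, sin = a.cos(), a.sin()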
@@ -746,8 +730,6 @@ class RotaryEmbedding(torch.nn.Module):


 def _rotate_half(x):
-    from einops import rearrange
-
     x = rearrange(x, "... (j d) -> ... j d", j=2)
     x1, x2 = x.unbind(dim=-2)
     return torch.cat((-x2, x1), dim=-1)
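_rotate_half is the standard half-rotation used by rotary position embeddings. A minimal sketch of how it typically combines with the cos/sin tables produced by RotaryEmbedding; apply_rotary below and the base of 10000 are illustrative assumptions, not code from this file:

import torch
from einops import rearrange

def _rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(t, cos, sin):
    # t: (batch, seq, heads, head_dim); cos/sin broadcast over batch and heads.
    return t * cos + _rotate_half(t) * sin

batch, seq, heads, head_dim = 1, 16, 2, 8
t = torch.randn(batch, seq, heads, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = rearrange(torch.cat((freqs, freqs), dim=-1), "n d -> 1 n 1 d")
print(apply_rotary(t, emb.cos(), emb.sin()).shape)  # torch.Size([1, 16, 2, 8])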