Refine model of qwen.
This commit is contained in:
parent
94ecf0f561
commit
69cb525ab0
|
@ -4,16 +4,13 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import importlib
|
||||
import math
|
||||
import inspect
|
||||
import pathlib
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import warnings
|
||||
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
|
||||
|
@ -30,6 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
|
|||
from transformers.utils import logging
|
||||
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
|
||||
from configuration_qwen import QWenConfig
|
||||
from qwen_generation_utils import (
|
||||
|
@ -63,11 +61,9 @@ class QWenAttention(nn.Module):
|
|||
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
|
||||
|
||||
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
|
||||
|
||||
self.use_dynamic_ntk = config.use_dynamic_ntk
|
||||
self.use_logn_attn = config.use_logn_attn
|
||||
|
||||
logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
|
||||
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
|
||||
|
@ -120,7 +116,7 @@ class QWenAttention(nn.Module):
|
|||
present = (key, value)
|
||||
|
||||
key_size = key.size(1)
|
||||
if key_size > self.seq_length and self.use_logn_attn and not self.training:
|
||||
if key_size > self.seq_length and not self.training:
|
||||
seq_start = key.size(1) - query.size(1)
|
||||
seq_end = key.size(1)
|
||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||
|
@ -143,14 +139,11 @@ class QWenAttention(nn.Module):
|
|||
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
||||
else:
|
||||
attention_mask = causal_mask
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
|
||||
|
||||
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
|
||||
attn_output = self.c_proj(context_layer)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -184,7 +177,6 @@ class QWenBlock(nn.Module):
|
|||
hidden_size,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.mlp = QWenMLP(config)
|
||||
|
||||
def forward(
|
||||
|
@ -202,22 +194,16 @@ class QWenBlock(nn.Module):
|
|||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_outputs[0]
|
||||
|
||||
outputs = attn_outputs[1:]
|
||||
|
||||
residual = hidden_states
|
||||
layernorm_input = attn_output + residual
|
||||
|
||||
layernorm_output = self.ln_2(layernorm_input)
|
||||
|
||||
residual = layernorm_input
|
||||
mlp_output = self.mlp(layernorm_output)
|
||||
hidden_states = residual + mlp_output
|
||||
|
||||
outputs = (hidden_states,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -732,8 +718,6 @@ class RotaryEmbedding(torch.nn.Module):
|
|||
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
from einops import rearrange
|
||||
|
||||
emb = rearrange(emb, "n d -> 1 n 1 d")
|
||||
|
||||
cos, sin = emb.cos(), emb.sin()
|
||||
|
@ -746,8 +730,6 @@ class RotaryEmbedding(torch.nn.Module):
|
|||
|
||||
|
||||
def _rotate_half(x):
|
||||
from einops import rearrange
|
||||
|
||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||
x1, x2 = x.unbind(dim=-2)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
|
Loading…
Reference in New Issue