Refine the Qwen model.

Colin 2024-01-07 22:49:21 +08:00
parent 94ecf0f561
commit 69cb525ab0
1 changed file with 3 additions and 21 deletions


@@ -4,16 +4,13 @@
# LICENSE file in the root directory of this source tree.
import copy
import importlib
import math
import inspect
import pathlib
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import warnings
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
@@ -30,6 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from torch import nn
from einops import rearrange
from configuration_qwen import QWenConfig
from qwen_generation_utils import (
@@ -63,11 +61,9 @@ class QWenAttention(nn.Module):
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
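(Aside: the `logn_list` above is the log-n attention scaling used for length extrapolation. Positions past the training length `seq_length` scale the query by log_seq_length(i) = ln(i) / ln(seq_length), which is intended to keep attention entropy roughly stable as the context grows. A minimal sketch, assuming a training length of 2048; the real value comes from the config, not this diff:

    import math
    import torch

    def build_logn_scale(train_len: int = 2048, max_len: int = 32768) -> torch.Tensor:
        # Positions up to the training length keep scale 1.0; beyond it the
        # query is scaled by log_{train_len}(i) = ln(i) / ln(train_len) > 1.
        scale = [math.log(i, train_len) if i > train_len else 1.0
                 for i in range(1, max_len)]
        # Shape (1, max_len - 1, 1, 1) so it broadcasts over
        # (batch, seq, heads, head_dim) query tensors.
        return torch.tensor(scale)[None, :, None, None]

    logn = build_logn_scale()
    print(logn[0, 2047, 0, 0].item())  # 1.0 at the training boundary
    print(logn[0, 4095, 0, 0].item())  # ~1.09 for position 4096
)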
@@ -120,7 +116,7 @@ class QWenAttention(nn.Module):
present = (key, value)
key_size = key.size(1)
if key_size > self.seq_length and self.use_logn_attn and not self.training:
if key_size > self.seq_length and not self.training:
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
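(The slice aligns the scale table with the absolute positions of the current query window, which is shorter than the cached key during incremental decoding. A rough sketch of the multiply that follows, reusing the `logn` table from the sketch above and assuming the (batch, seq, heads, head_dim) layout implied by the `[None, :, None, None]` shape; all sizes are illustrative only:

    import torch

    batch, heads, head_dim = 1, 32, 128
    key_len, query_len = 3000, 16                  # decoding past the assumed 2048 training length
    query = torch.randn(batch, query_len, heads, head_dim)

    seq_start = key_len - query_len                # absolute position of the first query token
    seq_end = key_len
    scale = logn[:, seq_start:seq_end, :, :].type_as(query)
    query = query * scale.expand_as(query)         # out-of-range positions get a > 1 scale
)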
@@ -143,14 +139,11 @@ class QWenAttention(nn.Module):
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
return outputs
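(For reference, `F.scaled_dot_product_attention` computes softmax(QK^T / sqrt(d) + mask) V in one fused call; with a boolean mask, True marks positions that may attend and False positions are filled with -inf before the softmax. A self-contained sketch with illustrative shapes, not taken from this diff:

    import torch
    import torch.nn.functional as F

    batch, heads, q_len, k_len, head_dim = 1, 2, 4, 4, 8
    query = torch.randn(batch, heads, q_len, head_dim)
    key = torch.randn(batch, heads, k_len, head_dim)
    value = torch.randn(batch, heads, k_len, head_dim)

    # True means "may attend": a lower-triangular causal mask.
    causal_mask = torch.ones(q_len, k_len, dtype=torch.bool).tril()
    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
    print(attn_output.shape)  # torch.Size([1, 2, 4, 8])

For a purely causal pattern, `is_causal=True` avoids materializing the mask; the model builds an explicit one here so it can be merged with the padding `attention_mask` via `masked_fill`.)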
@@ -184,7 +177,6 @@ class QWenBlock(nn.Module):
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
def forward(
@@ -202,22 +194,16 @@ class QWenBlock(nn.Module):
layer_past=layer_past,
attention_mask=attention_mask,
)
attn_output = attn_outputs[0]
outputs = attn_outputs[1:]
residual = hidden_states
layernorm_input = attn_output + residual
layernorm_output = self.ln_2(layernorm_input)
residual = layernorm_input
mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output
outputs = (hidden_states,) + outputs
return outputs
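(This is the usual pre-norm residual wiring: each sub-layer reads a normalized copy of the stream but adds its output back to the un-normalized residual. A condensed stand-in with plain `nn.LayerNorm` / `nn.Linear` placeholders where the model's own norm, QWenAttention and QWenMLP would sit:

    import torch
    from torch import nn

    class PreNormBlock(nn.Module):
        def __init__(self, hidden_size: int = 64):
            super().__init__()
            self.ln_1 = nn.LayerNorm(hidden_size)
            self.attn = nn.Linear(hidden_size, hidden_size)  # placeholder for QWenAttention
            self.ln_2 = nn.LayerNorm(hidden_size)
            self.mlp = nn.Linear(hidden_size, hidden_size)   # placeholder for QWenMLP

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Attention sub-layer: normalize, transform, add back to the residual.
            attn_output = self.attn(self.ln_1(hidden_states))
            layernorm_input = attn_output + hidden_states
            # MLP sub-layer: same pattern, the residual is the pre-norm stream.
            mlp_output = self.mlp(self.ln_2(layernorm_input))
            return layernorm_input + mlp_output

    x = torch.randn(2, 5, 64)
    print(PreNormBlock()(x).shape)  # torch.Size([2, 5, 64])

The real block also threads `layer_past` and `attention_mask` through and returns the attention cache alongside the hidden states; that bookkeeping is omitted here.)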
@@ -732,8 +718,6 @@ class RotaryEmbedding(torch.nn.Module):
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
from einops import rearrange
emb = rearrange(emb, "n d -> 1 n 1 d")
cos, sin = emb.cos(), emb.sin()
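(For context, this builds the standard rotary-embedding tables: `inv_freq` holds one frequency per channel pair, the outer product gives a (position, head_dim/2) grid of angles, and duplicating it along the last axis covers the full head dimension before broadcasting. A self-contained sketch with an assumed base of 10000 and illustrative sizes:

    import torch
    from einops import rearrange

    def rotary_tables(seq_len: int = 16, head_dim: int = 8, base: float = 10000.0):
        # Frequencies decay geometrically with the channel-pair index.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        seq = torch.arange(seq_len).float()
        freqs = torch.outer(seq, inv_freq)       # (seq_len, head_dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
        emb = rearrange(emb, "n d -> 1 n 1 d")   # broadcast over batch and heads
        return emb.cos(), emb.sin()

    cos, sin = rotary_tables()
    print(cos.shape)  # torch.Size([1, 16, 1, 8])
)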
@@ -746,8 +730,6 @@ class RotaryEmbedding(torch.nn.Module):
def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
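(`_rotate_half` splits the channels into two halves (x1, x2) and returns (-x2, x1); paired with the cos/sin tables this realizes the rotation q' = q * cos + rotate_half(q) * sin. A simplified sketch of that use, reusing `rotary_tables` from the sketch above; the model's own apply helper may differ in details not shown in this diff:

    import torch
    from einops import rearrange

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into two halves and map (x1, x2) -> (-x2, x1).
        x = rearrange(x, "... (j d) -> ... j d", j=2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Rotate every channel pair of t by its position-dependent angle.
        return t * cos + rotate_half(t) * sin

    q = torch.randn(1, 16, 4, 8)            # (batch, seq, heads, head_dim)
    cos, sin = rotary_tables()              # from the sketch above
    print(apply_rotary(q, cos, sin).shape)  # torch.Size([1, 16, 4, 8])
)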