Refine model of qwen.
This commit is contained in:
		
							parent
							
								
									94ecf0f561
								
							
						
					
					
						commit
						69cb525ab0
					
				| 
						 | 
				
			
			@ -4,16 +4,13 @@
 | 
			
		|||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import copy
 | 
			
		||||
import importlib
 | 
			
		||||
import math
 | 
			
		||||
import inspect
 | 
			
		||||
import pathlib
 | 
			
		||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
from torch.nn import CrossEntropyLoss
 | 
			
		||||
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
 | 
			
		||||
| 
						 | 
				
			
			@ -30,6 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 | 
			
		|||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from torch import nn
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
from configuration_qwen import QWenConfig
 | 
			
		||||
from qwen_generation_utils import (
 | 
			
		||||
| 
						 | 
				
			
			@ -63,11 +61,9 @@ class QWenAttention(nn.Module):
 | 
			
		|||
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
 | 
			
		||||
 | 
			
		||||
        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
 | 
			
		||||
 | 
			
		||||
        self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
 | 
			
		||||
 | 
			
		||||
        self.use_dynamic_ntk = config.use_dynamic_ntk
 | 
			
		||||
        self.use_logn_attn = config.use_logn_attn
 | 
			
		||||
 | 
			
		||||
        logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
 | 
			
		||||
        logn_tensor = torch.tensor(logn_list)[None, :, None, None]
 | 
			
		||||
| 
						 | 
				
			
			@ -120,7 +116,7 @@ class QWenAttention(nn.Module):
 | 
			
		|||
        present = (key, value)
 | 
			
		||||
 | 
			
		||||
        key_size = key.size(1)
 | 
			
		||||
        if key_size > self.seq_length and self.use_logn_attn and not self.training:
 | 
			
		||||
        if key_size > self.seq_length and not self.training:
 | 
			
		||||
            seq_start = key.size(1) - query.size(1)
 | 
			
		||||
            seq_end = key.size(1)
 | 
			
		||||
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
 | 
			
		||||
| 
						 | 
				
			
			@ -143,14 +139,11 @@ class QWenAttention(nn.Module):
 | 
			
		|||
                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
 | 
			
		||||
        else:
 | 
			
		||||
            attention_mask = causal_mask
 | 
			
		||||
 | 
			
		||||
        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 | 
			
		||||
 | 
			
		||||
        attn_output = self.c_proj(context_layer)
 | 
			
		||||
 | 
			
		||||
        outputs = (attn_output, present)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -184,7 +177,6 @@ class QWenBlock(nn.Module):
 | 
			
		|||
            hidden_size,
 | 
			
		||||
            eps=config.layer_norm_epsilon,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.mlp = QWenMLP(config)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
| 
						 | 
				
			
			@ -202,22 +194,16 @@ class QWenBlock(nn.Module):
 | 
			
		|||
            layer_past=layer_past,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        attn_output = attn_outputs[0]
 | 
			
		||||
 | 
			
		||||
        outputs = attn_outputs[1:]
 | 
			
		||||
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        layernorm_input = attn_output + residual
 | 
			
		||||
 | 
			
		||||
        layernorm_output = self.ln_2(layernorm_input)
 | 
			
		||||
 | 
			
		||||
        residual = layernorm_input
 | 
			
		||||
        mlp_output = self.mlp(layernorm_output)
 | 
			
		||||
        hidden_states = residual + mlp_output
 | 
			
		||||
 | 
			
		||||
        outputs = (hidden_states,) + outputs
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -732,8 +718,6 @@ class RotaryEmbedding(torch.nn.Module):
 | 
			
		|||
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
 | 
			
		||||
 | 
			
		||||
            emb = torch.cat((freqs, freqs), dim=-1)
 | 
			
		||||
            from einops import rearrange
 | 
			
		||||
 | 
			
		||||
            emb = rearrange(emb, "n d -> 1 n 1 d")
 | 
			
		||||
 | 
			
		||||
            cos, sin = emb.cos(), emb.sin()
 | 
			
		||||
| 
						 | 
				
			
			@ -746,8 +730,6 @@ class RotaryEmbedding(torch.nn.Module):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def _rotate_half(x):
 | 
			
		||||
    from einops import rearrange
 | 
			
		||||
 | 
			
		||||
    x = rearrange(x, "... (j d) -> ... j d", j=2)
 | 
			
		||||
    x1, x2 = x.unbind(dim=-2)
 | 
			
		||||
    return torch.cat((-x2, x1), dim=-1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue