2024-01-03 20:26:26 +08:00
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import importlib
import math
2024-01-07 16:15:27 +08:00
import inspect
2024-01-03 20:26:26 +08:00
import pathlib
from typing import TYPE_CHECKING , Optional , Tuple , Union , Callable , List , Any , Generator
import torch
import torch . nn . functional as F
import torch . utils . checkpoint
import warnings
from torch . nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer , GenerationConfig , StoppingCriteriaList
from transformers . generation . logits_process import LogitsProcessorList
if TYPE_CHECKING :
from transformers . generation . streamers import BaseStreamer
from transformers . generation . utils import GenerateOutput
from transformers . modeling_outputs import (
BaseModelOutputWithPast ,
CausalLMOutputWithPast ,
)
from transformers . modeling_utils import PreTrainedModel
from transformers . utils import logging
try :
from einops import rearrange
except ImportError :
rearrange = None
from torch import nn
# Hardware / library capability flags, evaluated once at import time.
SUPPORT_CUDA = torch.cuda.is_available()
# bf16 is only reported when a CUDA device exists and claims bf16 support.
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
# fp16 is treated as "fast" only on compute capability major >= 7 (device 0 is probed).
SUPPORT_FP16 = SUPPORT_CUDA and 7 <= torch.cuda.get_device_capability(0)[0]
# Torch 2.x is required for F.scaled_dot_product_attention (used by QWenAttention).
SUPPORT_TORCH2 = hasattr(torch, "__version__") and 2 <= int(torch.__version__.partition(".")[0])
2024-01-03 21:03:27 +08:00
from configuration_qwen import QWenConfig
from qwen_generation_utils import (
2024-01-03 20:26:26 +08:00
HistoryType ,
make_context ,
decode_tokens ,
get_stop_words_ids ,
StopWordsLogitsProcessor ,
)
logger = logging . get_logger ( __name__ )
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = [ " qwen-7b " ]
_ERROR_BAD_CHAT_FORMAT = """ \
We detect you are probably using the pretrained model ( rather than chat model ) for chatting , since the chat_format in generation_config is not " chatml " .
If you are directly using the model downloaded from Huggingface , please make sure you are using our " Qwen/Qwen-7B-Chat " Huggingface model ( rather than " Qwen/Qwen-7B " ) when you call model . chat ( ) .
我们检测到您可能在使用预训练模型 ( 而非chat模型 ) 进行多轮chat , 因为您当前在generation_config指定的chat_format , 并未设置为我们在对话中所支持的 " chatml " 格式 。
如果您在直接使用我们从Huggingface提供的模型 , 请确保您在调用model . chat ( ) 时 , 使用的是 " Qwen/Qwen-7B-Chat " 模型 ( 而非 " Qwen/Qwen-7B " 预训练模型 ) 。
"""
_SENTINEL = object ( )
_ERROR_STREAM_IN_CHAT = """ \
Pass argument ` stream ` to model . chat ( ) is buggy , deprecated , and marked for removal . Please use model . chat_stream ( . . . ) instead of model . chat ( . . . , stream = True ) .
向model . chat ( ) 传入参数stream的用法可能存在Bug , 该用法已被废弃 , 将在未来被移除 。 请使用model . chat_stream ( . . . ) 代替model . chat ( . . . , stream = True ) 。
"""
apply_rotary_emb_func = None
rms_norm = None
def quantize_cache_v(fdata, bits, qmax, qmin):
    """Asymmetrically quantize a 4-D KV-cache tensor to uint8.

    One (scale, zero) pair is computed per ``fdata[i, j]`` slice (dims 2 and
    beyond are flattened together for the min/max). In this file the caller
    passes ``key/value.permute(0, 2, 1, 3)``, i.e. (batch, heads, seq, head_dim).

    Args:
        fdata: 4-D floating tensor to quantize.
        bits: accepted for API compatibility; the integer range is taken from
            ``qmin``/``qmax`` and storage is always uint8.
        qmax, qmin: 0-d tensors holding the integer range bounds. They are
            moved to ``fdata``'s device if needed.

    Returns:
        ``(qdata, scale, zero)``. ``qdata`` is uint8 with ``fdata``'s shape.
        ``scale`` and ``zero`` have shape ``(d0, d1, d2, 1)``, so that
        ``scale * (qdata - zero)`` approximately reconstructs ``fdata``.
    """
    qtype = torch.uint8
    device = fdata.device
    shape = fdata.shape

    fdata_cal = torch.flatten(fdata, 2)
    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)

    # Compute params
    if qmax.device != fmax.device:
        qmax = qmax.to(device)
        qmin = qmin.to(device)
    scale = (fmax - fmin) / (qmax - qmin)
    # A constant slice (fmax == fmin) yields scale == 0, which made the zero
    # point `qmin - fmin / scale` inf/NaN. Any non-zero scale reproduces a
    # constant exactly, so fall back to 1 there.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    zero = qmin - fmin / scale
    # Broadcast the per-slice parameters along dim 2 (the sequence axis for the cache).
    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()

    # Quantize. Round to nearest before the integer cast: `.to(uint8)` alone
    # truncates, which biases every reconstructed value downward by up to one
    # full quantization step instead of at most half a step.
    res_data = fdata / scale + zero
    qdata = torch.round(torch.clamp(res_data, qmin, qmax)).to(qtype)
    return qdata.contiguous(), scale, zero
def dequantize_cache_torch(qdata, scale, zero):
    """Map quantized cache integers back to real values.

    Inverse of the affine map used by ``quantize_cache_v``: the result is
    ``scale * (qdata - zero)``, evaluated element-wise with broadcasting.
    """
    shifted = qdata - zero
    return scale * shifted
class QWenAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings, optional logn query
    scaling and an optional 8-bit quantized KV cache.

    ``forward`` returns ``(attn_output, present)`` plus the attention weights
    when ``output_attentions`` is set (not available on the
    scaled_dot_product_attention path).
    """

    def __init__(self, config):
        super().__init__()

        # Not read in this class; QWenModel lists "attn.masked_bias" in
        # _keys_to_ignore_on_load_missing.
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
        self.seq_length = config.seq_length

        self.hidden_size = config.hidden_size
        self.split_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        self.scale_attn_weights = True

        self.projection_size = config.kv_channels * config.num_attention_heads

        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )

        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)

        # NOTE(review): in/out features read as swapped (projection -> hidden
        # would be expected); harmless while hidden_size == projection_size.
        self.c_proj = nn.Linear(
            config.hidden_size, self.projection_size, bias=not config.no_bias
        )

        self.is_fp32 = not (config.bf16 or config.fp16)
        self.bf16 = config.bf16

        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.use_logn_attn = config.use_logn_attn

        # logn scaling factors: log_{seq_length}(i) for positions beyond the
        # trained length, 1 otherwise; indexed by (1-based position - 1).
        logn_list = [
            math.log(i, self.seq_length) if i > self.seq_length else 1
            for i in range(1, 32768)
        ]
        logn_tensor = torch.tensor(logn_list)[None, :, None, None]
        self.register_buffer("logn_tensor", logn_tensor, persistent=False)

        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
        self.use_cache_kernel = config.use_cache_kernel if hasattr(config, 'use_cache_kernel') else False

        # Quantization range bounds, stored in the model's compute dtype.
        cache_dtype = torch.float
        if self.bf16:
            cache_dtype = torch.bfloat16
        elif config.fp16:
            cache_dtype = torch.float16
        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)

        # Always defined so `_attn` can test it. Use the guarded attributes
        # above rather than re-reading config, which may lack these fields.
        self.cache_kernels = None
        if self.use_cache_quantization and self.use_cache_kernel:
            # pre check if the support files existing
            module_root = pathlib.Path(__file__).parent
            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
            if any(not (module_root / src).is_file() for src in src_files):
                warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
            else:
                try:
                    from .cpp_kernels import cache_autogptq_cuda_256
                    self.cache_kernels = cache_autogptq_cuda_256
                except ImportError:
                    warnings.warn("Failed to import KV cache kernels.")

    @staticmethod
    def _fp16_contiguous(tensor):
        # The cache kernels are called with contiguous fp16 operands.
        if tensor.dtype == torch.float16:
            return tensor.contiguous()
        return tensor.to(torch.float16).contiguous()

    def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
        """Eager attention: softmax(q k^T / sqrt(d) + masks) v.

        ``query`` is laid out (batch, heads, q_len, head_dim) by ``forward``.
        With cache quantization, ``key``/``value`` are ``(q, scale, zero)``
        triples that are either consumed by the CUDA kernels or dequantized.
        Returns ``(attn_output transposed to (batch, q_len, heads, head_dim),
        attn_weights)``.
        """
        device = query.device
        if self.use_cache_quantization:
            qk, qk_scale, qk_zero = key
            if self.use_cache_kernel and self.cache_kernels is not None:
                shape = query.shape[:-1] + (qk.shape[-2],)
                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                self.cache_kernels.vecquant8matmul_batched_faster_old(
                    self._fp16_contiguous(query),
                    qk.transpose(-1, -2).contiguous(),
                    attn_weights,
                    self._fp16_contiguous(qk_scale),
                    self._fp16_contiguous(qk_zero),
                )
            else:
                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
                attn_weights = torch.matmul(query, key.transpose(-1, -2))
        else:
            attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            if self.use_cache_quantization:
                size_temp = value[0].size(-1)
            else:
                size_temp = value.size(-1)
            attn_weights = attn_weights / (size_temp ** 0.5)

        mask_value = torch.finfo(attn_weights.dtype).min
        if causal_mask is not None:
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        if self.softmax_in_fp32:
            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_weights = attn_weights.type(query.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        if self.use_cache_quantization:
            qv, qv_scale, qv_zero = value
            if self.use_cache_kernel and self.cache_kernels is not None:
                shape = attn_weights.shape[:-1] + (query.shape[-1],)
                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
                    self._fp16_contiguous(attn_weights),
                    # NOTE(review): an earlier comment said int32 here, but
                    # quantize_cache_v produces uint8 — confirm kernel contract.
                    qv.contiguous(),
                    attn_output,
                    self._fp16_contiguous(qv_scale),
                    self._fp16_contiguous(qv_zero),
                )
                if attn_output.dtype != query.dtype:
                    attn_output = attn_output.to(query.dtype)
                    attn_weights = attn_weights.to(query.dtype)
            else:
                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
                attn_output = torch.matmul(attn_weights, value)
        else:
            attn_output = torch.matmul(attn_weights, value)

        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Reshape the last dim into (num_heads, attn_head_size)."""
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Collapse the last two dims (heads, head_size) back into one."""
        tensor = tensor.contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ):
        # hidden_states: split on dim 2 and dim 1 treated as the sequence axis,
        # i.e. (batch, seq, hidden). encoder_* arguments are accepted but unused.
        mixed_x_layer = self.c_attn(hidden_states)

        query, key, value = mixed_x_layer.split(self.split_size, dim=2)

        # -> (batch, seq, heads, head_dim)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if rotary_pos_emb_list is not None:
            cur_len = query.shape[1]
            if len(rotary_pos_emb_list) == 1:
                # Slice the pos emb for current inference; q and k share it.
                pos_emb = [emb[:, -cur_len:, :, :] for emb in rotary_pos_emb_list[0]]
                query = apply_rotary_pos_emb(query, pos_emb)
                key = apply_rotary_pos_emb(key, pos_emb)
            else:
                # One rotary table per batch row (dynamic NTK per sample).
                query_list = []
                key_list = []
                for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                    pos_emb = [emb[:, -cur_len:, :, :] for emb in rotary_pos_emb]
                    query_list += [apply_rotary_pos_emb(query[i:i + 1, :, :], pos_emb)]
                    key_list += [apply_rotary_pos_emb(key[i:i + 1, :, :], pos_emb)]
                query = torch.cat(query_list, dim=0)
                key = torch.cat(key_list, dim=0)

        if self.use_cache_quantization:
            # Quantized cache is kept as (q, scale, zero) in (batch, heads, seq, dim).
            key = quantize_cache_v(key.permute(0, 2, 1, 3),
                                   bits=8,
                                   qmin=self.cache_qmin,
                                   qmax=self.cache_qmax)
            value = quantize_cache_v(value.permute(0, 2, 1, 3),
                                     bits=8,
                                     qmin=self.cache_qmin,
                                     qmax=self.cache_qmax)

        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
            if self.use_cache_quantization:
                # use_cache_quantization:
                # present=((q_key,key_scale,key_zero_point),
                #          (q_value,value_scale,value_zero_point))
                key = (torch.cat((past_key[0], key[0]), dim=2),
                       torch.cat((past_key[1], key[1]), dim=2),
                       torch.cat((past_key[2], key[2]), dim=2))
                value = (torch.cat((past_value[0], value[0]), dim=2),
                         torch.cat((past_value[1], value[1]), dim=2),
                         torch.cat((past_value[2], value[2]), dim=2))
            else:
                # not use_cache_quantization:
                # present=(key,value), each (batch, seq, heads, dim)
                key = torch.cat((past_key, key), dim=1)
                value = torch.cat((past_value, value), dim=1)

        if use_cache:
            present = (key, value)
        else:
            present = None

        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
        if key_size > self.seq_length and self.use_logn_attn and not self.training:
            # Scale queries by log_{seq_length}(position) beyond the trained length.
            seq_start = key_size - query.size(1)
            seq_end = key_size
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
            query = query * logn_tensor.expand_as(query)

        # A causal mask is only needed when queries and keys cover the same
        # span (no cache); with a cache the new queries see every cached key.
        if query.size(1) == key_size:
            causal_mask = torch.tril(
                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
            ).view(1, 1, key_size, key_size)
        else:
            causal_mask = None

        # -> (batch, heads, seq, dim)
        query = query.permute(0, 2, 1, 3)
        if not self.use_cache_quantization:
            key = key.permute(0, 2, 1, 3)
            value = value.permute(0, 2, 1, 3)

        if not self.use_cache_quantization and SUPPORT_TORCH2:
            if attention_mask is not None:
                # Broadcast the padding mask over the query positions. The size
                # must come from the query, not causal_mask: causal_mask is None
                # whenever q_len != k_len (cached decoding), and dereferencing
                # it raised AttributeError for e.g. batched generation.
                attention_mask = attention_mask.expand(
                    -1, -1, query.size(2), -1
                )
                if causal_mask is not None:
                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
            else:
                attention_mask = causal_mask
            attn_output = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask
            ).transpose(1, 2)
            attn_weight = None
        else:
            attn_output, attn_weight = self._attn(
                query, key, value, causal_mask, attention_mask, head_mask
            )

        context_layer = self._merge_heads(
            attn_output, self.num_heads, self.head_dim
        )

        attn_output = self.c_proj(context_layer)

        outputs = (attn_output, present)
        if output_attentions:
            if not self.use_cache_quantization and SUPPORT_TORCH2:
                raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
            else:
                outputs += (attn_weight,)

        return outputs
class QWenMLP(nn.Module):
    """Gated feed-forward layer computing ``c_proj(w1(x) * silu(w2(x)))``.

    Both input projections map ``hidden_size`` to ``intermediate_size // 2``.
    ``c_proj`` maps that width back to ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        half_width = config.intermediate_size // 2
        use_bias = not config.no_bias
        self.w1 = nn.Linear(config.hidden_size, half_width, bias=use_bias)
        self.w2 = nn.Linear(config.hidden_size, half_width, bias=use_bias)
        self.c_proj = nn.Linear(half_width, config.hidden_size, bias=use_bias)

    def forward(self, hidden_states):
        up = self.w1(hidden_states)
        gate = F.silu(self.w2(hidden_states))
        return self.c_proj(up * gate)
class QWenBlock(nn.Module):
    """One pre-norm decoder layer.

    The layer computes ``x + attn(ln_1(x))`` and then adds ``mlp(ln_2(.))`` as
    a second residual. It returns ``(hidden_states, present, [attn_weights])``
    when ``use_cache`` is set, otherwise ``(hidden_states, [attn_weights])``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.bf16 = config.bf16

        self.ln_1 = RMSNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = QWenAttention(config)
        self.ln_2 = RMSNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = QWenMLP(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            rotary_pos_emb_list,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # attn_outputs is (attn_output, present[, attn_weights]).
        extras = attn_outputs[1:]

        after_attn = attn_outputs[0] + hidden_states
        block_out = after_attn + self.mlp(self.ln_2(after_attn))

        if not use_cache:
            # Drop the `present` slot (QWenAttention sets it to None here).
            extras = extras[1:]
        return (block_out,) + extras
class QWenPreTrainedModel(PreTrainedModel):
    """Base class that plugs QWen into transformers' ``PreTrainedModel``.

    It supplies the config class, the weight-initialization scheme and the
    gradient-checkpointing toggle.
    """

    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["QWenBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

        # Any module owning a direct `c_proj` child (attention and MLP) gets
        # that projection re-drawn with std scaled by 1/sqrt(2 * num_layers).
        for param_name, param in module.named_parameters():
            if param_name != "c_proj.weight":
                continue
            depth_scaled_std = self.config.initializer_range / math.sqrt(
                2 * self.config.num_hidden_layers
            )
            param.data.normal_(mean=0.0, std=depth_scaled_std)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value
class QWenModel(QWenPreTrainedModel):
    """QWen decoder stack: token embedding, ``QWenBlock`` layers and a final RMSNorm.

    ``forward`` also prepares the additive padding mask and the rotary
    embeddings, including the per-sample dynamic-NTK alpha when enabled.
    """

    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size
        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False

        self.gradient_checkpointing = False
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length

        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

        self.drop = nn.Dropout(config.emb_dropout_prob)

        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(
                config.kv_channels * config.rotary_pct
            )
        dim = (
            self.rotary_ndims
            if self.rotary_ndims is not None
            else config.kv_channels
        )
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

        self.is_fp32 = not (config.bf16 or config.fp16)

        self.h = nn.ModuleList(
            [QWenBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_ntk_alpha(self, true_seq_len):
        """Return the dynamic-NTK alpha: 2**ceil(log2(len/seq_length) + 1) - 1, at least 1."""
        context_value = math.log(true_seq_len / self.seq_length, 2) + 1
        ntk_alpha = 2 ** math.ceil(context_value) - 1
        ntk_alpha = max(ntk_alpha, 1)
        return ntk_alpha

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        # Number of cached positions. A cache of all-None entries counts as no
        # cache (previously it crashed indexing into None).
        if past_key_values is None or past_key_values[0] is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        elif self.use_cache_quantization:
            # Layer cache: ((q_key, scale, zero), (q_value, scale, zero)),
            # each laid out (batch, heads, seq, dim).
            past_length = past_key_values[0][0][0].size(2)
        else:
            # Layer cache: (key, value), each (batch, seq, heads, dim).
            # The sequence axis is dim 1; the previous `.size(-2)` read the
            # heads axis instead.
            past_length = past_key_values[0][0].size(1)

        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            # 1/0 keep-mask -> additive mask of shape (batch, 1, 1, kv_len):
            # 0 where kept, dtype-min where masked.
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        encoder_attention_mask = None
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        # Total key/value length this step: new tokens plus cached ones.
        kv_seq_len = hidden_states.size()[1] + past_length

        if self.training or not self.use_dynamic_ntk:
            ntk_alpha_list = [1.0]
        elif kv_seq_len != hidden_states.size()[1]:
            # Incremental step: reuse the alphas chosen for the prompt.
            ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
        elif attention_mask is not None and kv_seq_len > self.seq_length:
            # Per-sample alpha from each row's unpadded length.
            true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
            ntk_alpha_list = [self.get_ntk_alpha(n) for n in true_seq_lens.tolist()]
        else:
            ntk_alpha_list = [self.get_ntk_alpha(kv_seq_len)]
        self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
        rotary_pos_emb_list = [
            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
        ]

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    rotary_pos_emb_list,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    rotary_pos_emb_list=rotary_pos_emb_list,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Attentions were collected but previously dropped from the tuple form.
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class QWenLMHeadModel ( QWenPreTrainedModel ) :
_keys_to_ignore_on_load_missing = [ r " h \ . \ d+ \ .attn \ .rotary_emb \ .inv_freq " ]
_keys_to_ignore_on_load_unexpected = [ r " h \ . \ d+ \ .attn \ .masked_bias " ]
def __init__(self, config):
    """Build the QWen causal-LM model.

    When none of bf16/fp16/fp32 is set in ``config``, the precision is chosen
    automatically (bf16, then fp16, then fp32) and written back into
    ``config``. The transformer and the LM head are then cast accordingly.
    """
    super().__init__(config)
    assert (
        config.bf16 + config.fp16 + config.fp32 <= 1
    ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"

    autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

    # `Logger.warn` is a deprecated alias; use `Logger.warning`.
    if autoset_precision:
        if SUPPORT_BF16:
            logger.warning(
                "The model is automatically converting to bf16 for faster inference. "
                "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
            )
            config.bf16 = True
        elif SUPPORT_FP16:
            logger.warning(
                "The model is automatically converting to fp16 for faster inference. "
                "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
            )
            config.fp16 = True
        else:
            config.fp32 = True

    if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
        logger.warning("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
    if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
        logger.warning("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
    if config.fp32:
        if SUPPORT_BF16:
            logger.warning("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
        elif SUPPORT_FP16:
            logger.warning("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

    self.transformer = QWenModel(config)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    if config.bf16:
        self.transformer.bfloat16()
        self.lm_head.bfloat16()
    if config.fp16:
        self.transformer.half()
        self.lm_head.half()
    self.post_init()
def get_output_embeddings(self):
    """Return ``self.lm_head``, the linear projection from hidden states to vocabulary logits."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Install *new_embeddings* as the LM head (``self.lm_head``)."""
    setattr(self, "lm_head", new_embeddings)
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
    """Build the keyword arguments for one generation step.

    When a (truthy) cache is present, only the last token of ``input_ids`` is
    fed. ``inputs_embeds`` is used only when there is no cache. The attention
    mask is dropped for batch size 1; otherwise ``kwargs["attention_mask"]``
    is forwarded.
    """
    if past_key_values:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = None if input_ids.size(0) == 1 else kwargs.get("attention_mask", None)

    use_embeds = inputs_embeds is not None and past_key_values is None
    model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}

    model_inputs.update(
        past_key_values=past_key_values,
        use_cache=kwargs.get("use_cache"),
        attention_mask=attention_mask,
    )
    return model_inputs
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the transformer, project to vocabulary logits and optionally compute the LM loss.

    With ``labels``, the loss is the cross-entropy between the logits at
    position t and the label at position t + 1. With ``return_dict`` false,
    the result is the tuple ``([loss,] logits, *transformer_outputs[1:])``.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    lm_logits = self.lm_head(transformer_outputs[0])

    loss = None
    if labels is not None:
        labels = labels.to(lm_logits.device)
        # Drop the last logit and the first label so they line up.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        criterion = CrossEntropyLoss()
        loss = criterion(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return output if loss is None else (loss,) + output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
@staticmethod
def _reorder_cache(
    past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
    """Reorder every cached tensor along dim 0 according to ``beam_idx``.

    ``beam_idx`` is moved to each cached tensor's device before selecting, so
    layers placed on different devices are handled.

    Returns:
        A tuple (one entry per layer) of tuples of re-indexed tensors, with the
        same nesting as ``past_key_values``.
    """

    def _pick(cached: torch.Tensor) -> torch.Tensor:
        return cached.index_select(0, beam_idx.to(cached.device))

    return tuple(tuple(map(_pick, layer_cache)) for layer_cache in past_key_values)
def chat(
    self,
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[HistoryType],
    system: str = "You are a helpful assistant.",
    stream: Optional[bool] = _SENTINEL,
    stop_words_ids: Optional[List[List[int]]] = None,
    generation_config: Optional[GenerationConfig] = None,
    **kwargs,
) -> Tuple[str, HistoryType]:
    """Generate one (non-streaming) chat reply and return it with the updated history.

    Args:
        tokenizer: Tokenizer used by ``make_context`` / ``decode_tokens``.
        query: The new user message.
        history: Previous turns; None means an empty history. It is deep-copied,
            so the caller's object is never modified.
        system: System prompt passed to ``make_context``.
        stream: Must be left at its default; ``chat_stream`` yields incremental
            text instead.
        stop_words_ids: Extra stop sequences (lists of token ids). The chat-format
            stop words from ``get_stop_words_ids`` are added to a copy of this
            list, so the caller's list is not modified.
        generation_config: Overrides ``self.generation_config``; its
            ``chat_format`` must be ``"chatml"``.
        **kwargs: Forwarded to ``self.generate``. ``max_window_size`` is read
            here (falling back to ``generation_config.max_window_size``) and is
            also still forwarded.

    Returns:
        ``(response, history)``: the decoded reply and the copied history with
        ``(query, response)`` appended.

    Raises:
        AssertionError: if ``stream`` is passed or ``chat_format`` is not
            ``"chatml"`` (these ``assert`` checks are skipped under ``python -O``).
    """
    generation_config = generation_config if generation_config is not None else self.generation_config

    assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
    assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT

    # Make a copy of the user's input so that it is left untouched.
    history = [] if history is None else copy.deepcopy(history)

    # Copy before extending: extending the caller's list in place would mutate
    # their argument and make every repeated call append the chat stop words again.
    stop_words_ids = [] if stop_words_ids is None else list(stop_words_ids)

    max_window_size = kwargs.get('max_window_size', None)
    if max_window_size is None:
        max_window_size = generation_config.max_window_size
    raw_text, context_tokens = make_context(
        tokenizer,
        query,
        history=history,
        system=system,
        max_window_size=max_window_size,
        chat_format=generation_config.chat_format,
    )

    stop_words_ids.extend(get_stop_words_ids(
        generation_config.chat_format, tokenizer
    ))
    input_ids = torch.tensor([context_tokens]).to(self.device)
    outputs = self.generate(
        input_ids,
        stop_words_ids=stop_words_ids,
        return_dict_in_generate=False,
        generation_config=generation_config,
        **kwargs,
    )

    response = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        chat_format=generation_config.chat_format,
        verbose=False,
        errors='replace'
    )

    # As history is a copy of the user inputs, we can always return the new
    # turn to the user. Separating input history and output history also lets
    # the user implement more complex history management.
    history.append((query, response))

    return response, history
def chat_stream(
    self,
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[HistoryType],
    system: str = "You are a helpful assistant.",
    stop_words_ids: Optional[List[List[int]]] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    generation_config: Optional[GenerationConfig] = None,
    **kwargs,
) -> Generator[str, Any, None]:
    """Stream a chat reply, yielding the decoded text generated so far.

    Args:
        tokenizer: Tokenizer used by ``make_context`` and for decoding.
        query: The new user message.
        history: Previous turns (None means empty); only read, never appended to.
        system: System prompt passed to ``make_context``.
        stop_words_ids: Extra stop sequences; the chat-format stop words are
            added to a copy, so the caller's list is not modified.
        logits_processor: Extra processors. A new list holding these plus a
            ``StopWordsLogitsProcessor`` is built; the caller's list is not
            modified.
        generation_config: Overrides ``self.generation_config``; its
            ``chat_format`` must be ``"chatml"``.
        **kwargs: Forwarded to ``generate_stream``; ``max_window_size`` is read
            here (falling back to ``generation_config.max_window_size``).

    Returns:
        A generator; each item is the full reply decoded so far (cumulative,
        with special tokens skipped), not just the newest piece.

    Note:
        Requires ``transformers_stream_generator``. Each call binds its
        ``generate`` / ``sample_stream`` onto this model's class.
    """
    generation_config = generation_config if generation_config is not None else self.generation_config
    assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
    if history is None:
        history = []
    # Copy so extending below does not mutate the caller's list.
    stop_words_ids = [] if stop_words_ids is None else list(stop_words_ids)

    max_window_size = kwargs.get('max_window_size', None)
    if max_window_size is None:
        max_window_size = generation_config.max_window_size
    _, context_tokens = make_context(
        tokenizer,
        query,
        history=history,
        system=system,
        max_window_size=max_window_size,
        chat_format=generation_config.chat_format,
    )

    stop_words_ids.extend(get_stop_words_ids(
        generation_config.chat_format, tokenizer
    ))
    stop_words_logits_processor = StopWordsLogitsProcessor(
        stop_words_ids=stop_words_ids,
        eos_token_id=generation_config.eos_token_id,
    )
    # Build a fresh list: appending to the caller's list would make it grow by
    # one stop-word processor on every call.
    if logits_processor is None:
        logits_processor = LogitsProcessorList([stop_words_logits_processor])
    else:
        logits_processor = LogitsProcessorList([*logits_processor, stop_words_logits_processor])
    input_ids = torch.tensor([context_tokens]).to(self.device)

    from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
    self.__class__.generate_stream = NewGenerationMixin.generate
    self.__class__.sample_stream = NewGenerationMixin.sample_stream
    stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

    def stream_generator():
        outputs = []
        for token in self.generate_stream(
                input_ids,
                return_dict_in_generate=False,
                generation_config=stream_config,
                logits_processor=logits_processor,
                seed=-1,
                **kwargs):
            outputs.append(token.item())
            # Re-decode the whole prefix so multi-token characters render correctly.
            yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')

    return stream_generator()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[
        Callable[[int, torch.Tensor], List[int]]
    ] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Generate tokens, adding a stop-word logits processor when stop words are set.

    Args:
        inputs ... streamer: Forwarded unchanged to ``self.generate_base``
            (``logits_processor`` possibly extended, see below).
        logits_processor: Extra processors. When stop words are set, a new list
            holding these plus a ``StopWordsLogitsProcessor`` is passed on; the
            caller's list is not modified.
        **kwargs: ``stop_words_ids`` is popped here (falling back to
            ``generation_config.stop_words_ids`` when absent); everything else
            is forwarded to ``generate_base``.

    Returns:
        Whatever ``self.generate_base`` returns.
    """
    generation_config = generation_config if generation_config is not None else self.generation_config

    # Process stop_words_ids: an explicit kwarg wins over the generation config.
    stop_words_ids = kwargs.pop("stop_words_ids", None)
    if stop_words_ids is None and generation_config is not None:
        stop_words_ids = getattr(generation_config, "stop_words_ids", None)

    if stop_words_ids is not None:
        stop_words_logits_processor = StopWordsLogitsProcessor(
            stop_words_ids=stop_words_ids,
            eos_token_id=generation_config.eos_token_id,
        )
        # Build a fresh list rather than appending to the caller's, which would
        # otherwise gain another stop-word processor on every call.
        if logits_processor is None:
            logits_processor = LogitsProcessorList([stop_words_logits_processor])
        else:
            logits_processor = LogitsProcessorList([*logits_processor, stop_words_logits_processor])

    return self.generate_base(
        inputs,
        generation_config=generation_config,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        synced_gpus=synced_gpus,
        assistant_model=assistant_model,
        streamer=streamer,
        **kwargs,
    )
2024-01-07 16:15:27 +08:00
def generate_base(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Prepare generation state and run the sampling loop ``self.sample_base``.

    The numbered steps follow the structure of ``transformers``'
    ``GenerationMixin.generate``. It is presumably adapted from it, since it
    calls the same private helpers (``_prepare_model_inputs``,
    ``_get_logits_processor``, ...). Unlike that method, it does not dispatch
    on the generation mode: step 13 always calls ``self.sample_base``, which
    draws tokens with ``torch.multinomial``.

    Args:
        inputs: Prompt tensor, handed to ``self._prepare_model_inputs``.
        generation_config: When None, ``self.generation_config`` is used
            (possibly refreshed from the model config, see step 1). Either way
            it is deep-copied before ``kwargs`` are applied, so the original
            object is not modified.
        logits_processor / stopping_criteria: Extra processors / criteria merged
            with those derived from ``generation_config`` (steps 8 and 9).
        prefix_allowed_tokens_fn, negative_prompt_ids,
        negative_prompt_attention_mask: Forwarded to ``self._get_logits_processor``.
        synced_gpus: Treated as False when None; forwarded to ``sample_base``.
        assistant_model: Only used to compute ``generation_mode``, which is not
            used afterwards.
        streamer: Receives the prompt ids here and new tokens in ``sample_base``.
        **kwargs: Applied with ``generation_config.update``; the keys it does not
            consume become model kwargs and are validated.

    Returns:
        Whatever ``self.sample_base`` returns.

    Raises:
        ValueError: if ``streamer`` is given together with ``num_beams > 1``.
    """
    if synced_gpus is None:
        synced_gpus = False

    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    self._validate_model_class()

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    if generation_config is None:
        # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
        # two conditions must be met
        # 1) the generation config must have been created from the model config (`_from_model_config` field);
        # 2) the generation config must have seen no modification since its creation (the hash is the same).
        if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
            self.generation_config
        ):
            new_generation_config = GenerationConfig.from_model_config(self.config)
            if new_generation_config != self.generation_config:
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed soon, in a future version."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                )
                self.generation_config = new_generation_config
        generation_config = self.generation_config

    # Work on a private copy: the `update`/`pad_token_id`/`max_length` writes below must not leak to the caller.
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
    generation_config.validate()
    self._validate_model_kwargs(model_kwargs.copy())

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    # Without a pad token, fall back to the (first) EOS token id for padding finished rows.
    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
        if model_kwargs.get("attention_mask", None) is None:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        generation_config.pad_token_id = eos_token_id

    # 3. Define model inputs
    # inputs_tensor has to be defined
    # model_input_name is defined if model-specific keyword input is passed
    # otherwise model_input_name is None
    # all model-specific keyword inputs are removed from `model_kwargs`
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]

    # 4. Define other model kwargs
    # These two keys end up bound to `sample_base`'s named `output_attentions` /
    # `output_hidden_states` parameters when `model_kwargs` is splatted in step 13.
    model_kwargs["output_attentions"] = generation_config.output_attentions
    model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
    # generating the first new token or not, and we only want to use the embeddings for the first new token)
    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
        model_kwargs["use_cache"] = True
    else:
        model_kwargs["use_cache"] = generation_config.use_cache

    # Only build a default attention mask when forward() accepts one and none was given.
    accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
    requires_attention_mask = "encoder_outputs" not in model_kwargs

    if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
        )

    # decoder-only models should use left-padding for generation
    if not self.config.is_encoder_decoder:
        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
        # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
        if (
            generation_config.pad_token_id is not None
            and len(inputs_tensor.shape) == 2
            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
        ):
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if model is encoder decoder encoder_outputs are created
        # and added to `model_kwargs`
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name
        )

    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config.decoder_start_token_id,
            bos_token_id=generation_config.bos_token_id,
            device=inputs_tensor.device,
        )
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

    if streamer is not None:
        streamer.put(input_ids.cpu())

    # 6. Prepare `max_length` depending on other stopping criteria.
    input_ids_length = input_ids.shape[-1]
    # `max_length` counts as "default" when the caller did not pass it explicitly in kwargs.
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if generation_config.max_new_tokens is not None:
        if not has_default_max_length and generation_config.max_length is not None:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        # max_length is absolute (prompt + new tokens).
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
    self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

    # 7. determine generation mode
    # NOTE(review): `generation_mode` is never used below; `_get_generation_mode` may still
    # validate the config, but beam search / assisted / greedy settings are not honored
    # because step 13 always samples.
    generation_mode = self._get_generation_mode(generation_config, assistant_model)

    if streamer is not None and (generation_config.num_beams > 1):
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 8. prepare distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
        model_kwargs=model_kwargs,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )

    # 9. prepare stopping criteria
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    # 10. go into different generation modes
    # 11. prepare logits warper
    logits_warper = self._get_logits_warper(generation_config)

    # 12. expand input_ids with `num_return_sequences` additional sequences per batch
    input_ids, model_kwargs = self._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,
        is_encoder_decoder=self.config.is_encoder_decoder,
        **model_kwargs,
    )

    # 13. run sample
    return self.sample_base(
        input_ids,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
def sample_base(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
):
    """Autoregressively extend ``input_ids`` by multinomial sampling.

    Each step runs the model on the current sequence and takes the logits of
    the last position. It applies ``logits_processor`` and then
    ``logits_warper``, and samples one token per row from the softmax of the
    result.

    Args:
        input_ids: 2-D prompt ids; dim 0 indexes the sequences being generated.
        logits_processor / logits_warper: Applied in that order to the
            next-token logits; empty lists when None.
        stopping_criteria: Called as ``stopping_criteria(input_ids, scores)``
            after every step; an empty list when None.
        max_length: Accepted for signature compatibility but ignored. Length
            limits must come from ``stopping_criteria``.
        pad_token_id / eos_token_id: Fall back to ``self.generation_config``.
            Once a row has produced an EOS token, the tokens sampled for it
            afterwards are replaced by ``pad_token_id``.
        output_attentions / output_hidden_states: Forwarded to the model call
            (falling back to ``self.generation_config``).
        output_scores / return_dict_in_generate: When both are true, the
            processed scores are collected and handed to ``stopping_criteria``
            (otherwise it receives None). They do not change the return type.
        synced_gpus: When True, every rank keeps running forward passes until
            all ranks are finished. Requires an initialized
            ``torch.distributed`` process group.
        streamer: Receives each new token batch, then ``end()`` when done.
        **model_kwargs: Passed to ``prepare_inputs_for_generation`` and updated
            by ``_update_model_kwargs_for_generation`` after each step.

    Returns:
        ``input_ids`` with all sampled tokens appended along the last dimension.

    Raises:
        ValueError: if an EOS token id is set but no pad token id is available.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # Only the processed scores are kept (for `stopping_criteria`). Attentions and
    # hidden states are not accumulated because this loop returns token ids only.
    scores = () if (return_dict_in_generate and output_scores) else None

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False  # used by synced_gpus only

    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus every rank must keep calling `forward` until all ranks
            # are done. Each rank sends 0.0 once finished (1.0 otherwise), so a summed
            # value of 0.0 means everyone is finished and all ranks leave together.
            # Without this sync a finished rank would `continue` below forever.
            import torch.distributed as dist

            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        if scores is not None:
            scores += (next_token_scores,)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    return input_ids
# def backward(
# self,
# tokenizer,
# query: str,
# ):
# inputs = tokenizer.build_chat_input(query, history=[], role="user")
# inputs = inputs.to(next(self.parameters()).device)
# generation_config = copy.deepcopy(self.generation_config)
# inputs_tensor = inputs["input_ids"]
# input_ids = inputs_tensor.repeat_interleave(
# generation_config.num_return_sequences, dim=0
# )
# input_ids_in = input_ids
# batch_size, seq_length = input_ids_in.shape
# position_ids_in = (
# torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
# .unsqueeze(0)
# .repeat(batch_size, 1)
# )
# model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
# probs, next_tokens = self.transformer(
# **model_inputs,
# output_hidden_states=None,
# tokenizer=tokenizer,
# )
# next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# # probs_target = probs
# # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
# loss = probs[0, next_tokens]
# loss.backward()
# return loss
2024-01-03 20:26:26 +08:00
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding cos/sin table with a lazily grown cache.

    ``forward(max_seq_len, ntk_alpha)`` returns ``[cos, sin]``, each of shape
    ``(1, max_seq_len, 1, 2 * len(inv_freq))``; the last size equals ``dim``
    when ``dim`` is even. The table is rebuilt when a longer sequence or a
    different ``ntk_alpha`` is requested.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # `import importlib` alone does not guarantee that the `importlib.util`
        # submodule is loaded, so import it explicitly before calling find_spec.
        import importlib.util

        if importlib.util.find_spec("einops") is None:
            raise RuntimeError("einops is required for Rotary Embedding")
        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0
        # Not read inside this class. NOTE(review): presumably maintained by the owning model; verify.
        self._ntk_alpha_cached_list = [1.0]

    def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
        """Rebuild the cos/sin cache when ``seqlen`` exceeds it or ``ntk_alpha`` changed.

        The rotary base is scaled by ``ntk_alpha ** (dim / (dim - 2))``. The cache
        is sized to ``max(2 * seqlen, 16)`` positions so it does not have to be
        rebuilt on every small growth.
        """
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # (n, d) -> (1, n, 1, d): same view as einops' "n d -> 1 n 1 d",
            # without importing einops on every rebuild.
            emb = emb[None, :, None, :]
            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, ntk_alpha=1.0):
        """Return ``[cos, sin]`` restricted to the first ``max_seq_len`` positions."""
        self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, :max_seq_len], sin[:, :max_seq_len]]
def _rotate_half(x):
    """Rotate the two halves of the last dimension: ``[x1, x2] -> [-x2, x1]``.

    The last dimension is split into two equal contiguous halves, matching
    einops' ``"... (j d) -> ... j d"`` with ``j=2``. Raises ``RuntimeError`` if
    the last dimension has odd size.
    """
    # Pure-torch split: avoids importing einops on every call in this hot path.
    x1, x2 = x.reshape(*x.shape[:-1], 2, -1).unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
    """Apply rotary embedding to the first ``rotary_dim`` channels of the input.

    Arguments:
        t (tensor(batch_size, seq_len, n_head, head_dim)):
            the input embedding/hidden states
        freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
            the cached cos/sin position embeddings

    Returns:
        A tensor with the dtype of ``t``. The math runs in float32, and channels
        past ``rotary_dim`` are passed through unchanged.
    """
    rotary_dim = freqs[0].shape[-1]
    cos, sin = freqs
    t_fp32 = t.float()

    use_fused_kernel = apply_rotary_emb_func is not None and t.is_cuda
    if use_fused_kernel:
        # apply_rotary_emb in flash_attn requires cos/sin to be of
        # shape (seqlen, rotary_dim / 2) and applies rotary embedding
        # to the first rotary_dim of the input
        half_dim = rotary_dim // 2
        cos_half = cos.squeeze(0).squeeze(1)[:, :half_dim]
        sin_half = sin.squeeze(0).squeeze(1)[:, :half_dim]
        return apply_rotary_emb_func(t_fp32, cos_half, sin_half).type_as(t)

    rotated = t_fp32[..., :rotary_dim]
    untouched = t_fp32[..., rotary_dim:]
    rotated = (rotated * cos) + (_rotate_half(rotated) * sin)
    return torch.cat((rotated, untouched), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension, with a learnable scale.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``. When the
    module-level ``rms_norm`` function is available and the input is on CUDA,
    that function is used instead. Otherwise the normalization runs in float32,
    is cast back to the input dtype, and is then multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # One scale per channel, starting at 1 (pure normalization).
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        if rms_norm is None or not x.is_cuda:
            normalized = self._norm(x.float()).type_as(x)
            return normalized * self.weight
        return rms_norm(x, self.weight, self.eps)