Format qwen model.

Colin 2024-01-07 16:23:04 +08:00
parent 255a2ff71c
commit 611396b656
1 changed file with 309 additions and 156 deletions


@@ -38,7 +38,9 @@ from torch import nn
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
-SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+SUPPORT_TORCH2 = (
+    hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
+)
 
 from configuration_qwen import QWenConfig
@@ -70,6 +72,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
 apply_rotary_emb_func = None
 rms_norm = None
 
+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
@@ -85,17 +88,19 @@ def quantize_cache_v(fdata, bits, qmax, qmin):
         qmin = qmin.to(device)
     scale = (fmax - fmin) / (qmax - qmin)
     zero = qmin - fmin / scale
-    scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
-    zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
     # Quantize
     res_data = fdata / scale + zero
     qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
     return qdata.contiguous(), scale, zero
 
+
 def dequantize_cache_torch(qdata, scale, zero):
     data = scale * (qdata - zero)
     return data
 
+
 class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
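For context, a standalone sketch of the uint8 round trip that quantize_cache_v / dequantize_cache_torch implement for the KV cache. The scale/zero/clamp formulas mirror the hunk above; for simplicity this reduces fmax/fmin over the whole tensor, whereas the repo picks its own reduction axes (not visible in this hunk).

import torch

def quantize_uint8(fdata, qmin=0.0, qmax=255.0):
    # Per-tensor affine quantization; same scale/zero/clamp formulas as above.
    fmax, fmin = fdata.max(), fdata.min()
    scale = (fmax - fmin) / (qmax - qmin)
    zero = qmin - fmin / scale
    qdata = torch.clamp(fdata / scale + zero, qmin, qmax).to(torch.uint8)
    return qdata, scale, zero

def dequantize(qdata, scale, zero):
    # Mirrors dequantize_cache_torch: data = scale * (qdata - zero)
    return scale * (qdata - zero)

x = torch.randn(1, 2, 4, 8)          # (batch, head, seq, head_dim), made-up shapes
q, scale, zero = quantize_uint8(x)
x_hat = dequantize(q, scale, zero)
print((x - x_hat).abs().max())       # error on the order of (fmax - fmin) / 255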
@@ -138,12 +143,20 @@ class QWenAttention(nn.Module):
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
-        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
-        self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+        self.softmax_in_fp32 = (
+            config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
+        )
+        self.use_cache_quantization = (
+            config.use_cache_quantization
+            if hasattr(config, "use_cache_quantization")
+            else False
+        )
+        self.use_cache_kernel = (
+            config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
+        )
         cache_dtype = torch.float
         if self.bf16:
-            cache_dtype=torch.bfloat16
+            cache_dtype = torch.bfloat16
         elif config.fp16:
             cache_dtype = torch.float16
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
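As an aside (not a change made in this commit), the three hasattr/else-False conditionals reformatted above follow a common optional-config pattern; the same defaulting can be written with getattr, sketched here with a stand-in config object:

from types import SimpleNamespace

config = SimpleNamespace(softmax_in_fp32=True)  # stand-in config, for illustration only

# getattr with a default collapses each `x if hasattr(config, "x") else False` chain
softmax_in_fp32 = getattr(config, "softmax_in_fp32", False)                 # -> True
use_cache_quantization = getattr(config, "use_cache_quantization", False)  # -> False
use_cache_kernel = getattr(config, "use_cache_kernel", False)              # -> False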
@@ -152,19 +165,25 @@ class QWenAttention(nn.Module):
         if config.use_cache_quantization and config.use_cache_kernel:
             # pre check if the support files existing
             module_root = pathlib.Path(__file__).parent
-            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
+            src_files = (
+                "cache_autogptq_cuda_256.cpp",
+                "cache_autogptq_cuda_kernel_256.cu",
+            )
+            if any(not (module_root / src).is_file() for src in src_files):
                 warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
                 self.cache_kernels = None
             else:
                 try:
                     from .cpp_kernels import cache_autogptq_cuda_256
+
                     self.cache_kernels = cache_autogptq_cuda_256
                 except ImportError:
                     warnings.warn("Failed to import KV cache kernels.")
                     self.cache_kernels = None
 
-    def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+    def _attn(
+        self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
+    ):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
@@ -172,11 +191,18 @@ class QWenAttention(nn.Module):
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+                    query.contiguous()
+                    if query.dtype == torch.float16
+                    else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,
-                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+                    qk_scale.contiguous()
+                    if qk_scale.dtype == torch.float16
+                    else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous()
+                    if qk_zero.dtype == torch.float16
+                    else qk_zero.to(torch.float16).contiguous(),
+                )
                 # attn_weights = attn_weights.to(query.dtype).contiguous()
             else:
                 key = dequantize_cache_torch(qk, qk_scale, qk_zero)
@@ -189,7 +215,7 @@ class QWenAttention(nn.Module):
             size_temp = value[0].size(-1)
         else:
             size_temp = value.size(-1)
-        attn_weights = attn_weights / (size_temp ** 0.5)
+        attn_weights = attn_weights / (size_temp**0.5)
 
         mask_value = torch.finfo(attn_weights.dtype).min
         if causal_mask is not None:
@@ -217,11 +243,18 @@ class QWenAttention(nn.Module):
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+                    attn_weights.contiguous()
+                    if attn_weights.dtype == torch.float16
+                    else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(),  # dtype: int32
                     attn_output,
-                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+                    qv_scale.contiguous()
+                    if qv_scale.dtype == torch.float16
+                    else qv_scale.to(torch.float16).contiguous(),
+                    qv_zero.contiguous()
+                    if qv_zero.dtype == torch.float16
+                    else qv_zero.to(torch.float16).contiguous(),
+                )
                 if attn_output.dtype != query.dtype:
                     attn_output = attn_output.to(query.dtype)
                     attn_weights = attn_weights.to(query.dtype)
@@ -283,21 +316,26 @@ class QWenAttention(nn.Module):
                     rotary_pos_emb = (rotary_pos_emb,) * 2
                     q_pos_emb, k_pos_emb = rotary_pos_emb
                     # Slice the pos emb for current inference
-                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                    query_list += [
+                        apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
+                    ]
+                    key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
                 query = torch.cat(query_list, dim=0)
                 key = torch.cat(key_list, dim=0)
 
         if self.use_cache_quantization:
-            key = quantize_cache_v(key.permute(0, 2, 1, 3),
-                                   bits=8,
-                                   qmin=self.cache_qmin,
-                                   qmax=self.cache_qmax)
-            value = quantize_cache_v(value.permute(0, 2, 1, 3),
-                                     bits=8,
-                                     qmin=self.cache_qmin,
-                                     qmax=self.cache_qmax)
+            key = quantize_cache_v(
+                key.permute(0, 2, 1, 3),
+                bits=8,
+                qmin=self.cache_qmin,
+                qmax=self.cache_qmax,
+            )
+            value = quantize_cache_v(
+                value.permute(0, 2, 1, 3),
+                bits=8,
+                qmin=self.cache_qmin,
+                qmax=self.cache_qmax,
+            )
 
         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
@@ -305,12 +343,16 @@ class QWenAttention(nn.Module):
                 # use_cache_quantization:
                 # present=((q_key,key_scale,key_zero_point),
                 #          (q_value,value_scale,value_zero_point))
-                key = (torch.cat((past_key[0], key[0]), dim=2),
-                       torch.cat((past_key[1], key[1]), dim=2),
-                       torch.cat((past_key[2], key[2]), dim=2))
-                value = (torch.cat((past_value[0], value[0]), dim=2),
-                         torch.cat((past_value[1], value[1]), dim=2),
-                         torch.cat((past_value[2], value[2]), dim=2))
+                key = (
+                    torch.cat((past_key[0], key[0]), dim=2),
+                    torch.cat((past_key[1], key[1]), dim=2),
+                    torch.cat((past_key[2], key[2]), dim=2),
+                )
+                value = (
+                    torch.cat((past_value[0], value[0]), dim=2),
+                    torch.cat((past_value[1], value[1]), dim=2),
+                    torch.cat((past_value[2], value[2]), dim=2),
+                )
             else:
                 # not use_cache_quantization:
                 # present=(key,value)
@@ -347,11 +389,11 @@ class QWenAttention(nn.Module):
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
                 if causal_mask is not None:
-                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+                    attention_mask = attention_mask.masked_fill(
+                        ~causal_mask, torch.finfo(query.dtype).min
+                    )
             else:
                 attention_mask = causal_mask
             attn_output = F.scaled_dot_product_attention(
@@ -362,16 +404,16 @@ class QWenAttention(nn.Module):
             attn_output, attn_weight = self._attn(
                 query, key, value, causal_mask, attention_mask, head_mask
             )
-        context_layer = self._merge_heads(
-            attn_output, self.num_heads, self.head_dim
-        )
+        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
         attn_output = self.c_proj(context_layer)
 
         outputs = (attn_output, present)
         if output_attentions:
             if not self.use_cache_quantization and SUPPORT_TORCH2:
-                raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
+                raise ValueError(
+                    "Cannot output attentions while using scaled_dot_product_attention"
+                )
             else:
                 outputs += (attn_weight,)
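For context on the ValueError above: on the SUPPORT_TORCH2 path the module uses F.scaled_dot_product_attention, which fuses softmax(QK^T / sqrt(d)) @ V and returns only the attention output, so per-head attention weights are never materialized and cannot be returned. A minimal standalone sketch (shapes are illustrative, requires torch 2.x):

import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) -- illustrative shapes, not tied to any Qwen config
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Returns only the attention output; the softmax weights are not exposed,
# which is why output_attentions is rejected on this code path.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])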
@@ -507,7 +549,11 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
+        self.use_cache_quantization = (
+            self.config.use_cache_quantization
+            if hasattr(self.config, "use_cache_quantization")
+            else False
+        )
 
         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
@@ -521,25 +567,14 @@ class QWenModel(QWenPreTrainedModel):
             self.rotary_ndims = None
         else:
             assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                config.kv_channels * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else config.kv_channels
-        )
+            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
+        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
         self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
         self.is_fp32 = not (config.bf16 or config.fp16)
         self.h = nn.ModuleList(
-            [
-                QWenBlock(
-                    config
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+            [QWenBlock(config) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(
             self.embed_dim,
@@ -659,7 +694,12 @@ class QWenModel(QWenPreTrainedModel):
         else:
             ntk_alpha_list = []
             if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
+                true_seq_lens = (
+                    attention_mask.squeeze(1)
+                    .squeeze(1)
+                    .eq(0)
+                    .sum(dim=-1, dtype=torch.int32)
+                )
                 for i in range(hidden_states.size()[0]):
                     true_seq_len = true_seq_lens[i].item()
                     ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -669,7 +709,8 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+            for ntk_alpha in ntk_alpha_list
         ]
 
         hidden_states = self.drop(hidden_states)
@@ -686,7 +727,6 @@ class QWenModel(QWenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -727,7 +767,9 @@ class QWenModel(QWenPreTrainedModel):
                 presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (
+                    outputs[2 if use_cache else 1],
+                )
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -756,7 +798,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         super().__init__(config)
         assert (
             config.bf16 + config.fp16 + config.fp32 <= 1
-        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
+        ), 'Only one of "bf16", "fp16", "fp32" can be true'
 
         autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
@@ -764,27 +806,35 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if SUPPORT_BF16:
                 logger.warn(
                     "The model is automatically converting to bf16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.bf16 = True
             elif SUPPORT_FP16:
                 logger.warn(
                     "The model is automatically converting to fp16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.fp16 = True
             else:
                 config.fp32 = True
 
         if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
-            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
+            logger.warn(
+                'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
+            )
         if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
-            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
+            logger.warn(
+                "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
+            )
         if config.fp32:
             if SUPPORT_BF16:
-                logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
             elif SUPPORT_FP16:
-                logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -845,7 +895,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
@@ -887,7 +936,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # )
         # loss.backward()
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -904,7 +952,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def _reorder_cache(
         past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
-
         return tuple(
             tuple(
                 past_state.index_select(0, beam_idx.to(past_state.device))
@@ -924,10 +971,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
 
         assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         else:
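A hypothetical usage sketch of the chat() entry point whose signature appears above. The checkpoint name and loading flags are assumptions for illustration; model loading itself is not part of this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint name; any Qwen chat checkpoint shipping this modeling code works.
model_id = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()

response, history = model.chat(tokenizer, "Hello!", history=None)
print(response)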
@@ -937,7 +988,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if stop_words_ids is None:
             stop_words_ids = []
 
-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@@ -949,18 +1000,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )
 
-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
 
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
             return_dict_in_generate=False,
             generation_config=generation_config,
             **kwargs,
         )
         response = decode_tokens(
             outputs[0],
             tokenizer,
@@ -968,7 +1019,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             context_length=len(context_tokens),
             chat_format=generation_config.chat_format,
             verbose=False,
-            errors='replace'
+            errors="replace",
         )
 
         # as history is a copy of the user inputs,
@@ -980,24 +1031,28 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         return response, history
 
     def chat_stream(
         self,
         tokenizer: PreTrainedTokenizer,
         query: str,
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         stop_words_ids: Optional[List[List[int]]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Generator[str, Any, None]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
             stop_words_ids = []
 
-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@@ -1009,9 +1064,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )
 
-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
 
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
@@ -1023,22 +1078,31 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             logits_processor.append(stop_words_logits_processor)
 
         input_ids = torch.tensor([context_tokens]).to(self.device)
 
-        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+        from transformers_stream_generator.main import (
+            NewGenerationMixin,
+            StreamGenerationConfig,
+        )
+
         self.__class__.generate_stream = NewGenerationMixin.generate
         self.__class__.sample_stream = NewGenerationMixin.sample_stream
-        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+        stream_config = StreamGenerationConfig(
+            **generation_config.to_dict(), do_stream=True
+        )
 
         def stream_generator():
             outputs = []
             for token in self.generate_stream(
                 input_ids,
                 return_dict_in_generate=False,
                 generation_config=stream_config,
                 logits_processor=logits_processor,
                 seed=-1,
-                **kwargs):
+                **kwargs,
+            ):
                 outputs.append(token.item())
-                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
+                yield tokenizer.decode(
+                    outputs, skip_special_tokens=True, errors="ignore"
+                )
 
         return stream_generator()
@@ -1056,7 +1120,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
 
         # Process stop_words_ids.
         stop_words_ids = kwargs.pop("stop_words_ids", None)
@@ -1093,7 +1161,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[
+            Callable[[int, torch.Tensor], List[int]]
+        ] = None,
         synced_gpus: Optional[bool] = None,
         assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
@@ -1101,9 +1171,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-
         if synced_gpus is None:
             synced_gpus = False
 
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
@@ -1114,8 +1183,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # two conditions must be met
         # 1) the generation config must have been created from the model config (`_from_model_config` field);
         # 2) the generation config must have seen no modification since its creation (the hash is the same).
-        if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
-            self.generation_config
+        if (
+            self.generation_config._from_model_config
+            and self.generation_config._original_object_hash
+            == hash(self.generation_config)
         ):
             new_generation_config = GenerationConfig.from_model_config(self.config)
             if new_generation_config != self.generation_config:
@@ -1129,15 +1200,26 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config = self.generation_config
 
         generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        model_kwargs = generation_config.update(
+            **kwargs
+        )  # All unused kwargs must be model kwargs
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
 
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if (
+            generation_config.pad_token_id is None
+            and generation_config.eos_token_id is not None
+        ):
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -1146,7 +1228,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+            )
             generation_config.pad_token_id = eos_token_id
 
         # 3. Define model inputs
@@ -1169,12 +1253,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         else:
             model_kwargs["use_cache"] = generation_config.use_cache
 
-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        accepts_attention_mask = "attention_mask" in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+        if (
+            model_kwargs.get("attention_mask", None) is None
+            and requires_attention_mask
+            and accepts_attention_mask
+        ):
+            model_kwargs[
+                "attention_mask"
+            ] = self._prepare_attention_mask_for_generation(
+                inputs_tensor,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
             )
 
         # decoder-only models should use left-padding for generation
@@ -1184,7 +1278,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if (
             generation_config.pad_token_id is not None
             and len(inputs_tensor.shape) == 2
-            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+            > 0
         ):
             logger.warning(
                 "A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -1209,14 +1304,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 device=inputs_tensor.device,
             )
         else:
-            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+            input_ids = (
+                inputs_tensor
+                if model_input_name == "input_ids"
+                else model_kwargs.pop("input_ids")
+            )
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_max_length = (
+            kwargs.get("max_length") is None
+            and generation_config.max_length is not None
+        )
         if generation_config.max_new_tokens is not None:
             if not has_default_max_length and generation_config.max_length is not None:
                 logger.warning(
@@ -1225,8 +1327,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+            generation_config.max_length = (
+                generation_config.max_new_tokens + input_ids_length
+            )
+        self._validate_generated_length(
+            generation_config, input_ids_length, has_default_max_length
+        )
 
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
@@ -1264,7 +1370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
 
         # 10. go into different generation modes
         # 11. prepare logits warper
         logits_warper = self._get_logits_warper(generation_config)
@@ -1291,7 +1397,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             **model_kwargs,
         )
 
     def sample_base(
         self,
         input_ids: torch.LongTensor,
@@ -1309,10 +1414,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ):
-
         # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
         # if max_length is not None:
         #     warnings.warn(
         #         "`max_length` is deprecated in this function, use"
@@ -1320,18 +1430,40 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         #         UserWarning,
         #     )
         #     stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        logits_warper = (
+            logits_warper if logits_warper is not None else LogitsProcessorList()
+        )
+        pad_token_id = (
+            pad_token_id
+            if pad_token_id is not None
+            else self.generation_config.pad_token_id
+        )
+        eos_token_id = (
+            eos_token_id
+            if eos_token_id is not None
+            else self.generation_config.eos_token_id
+        )
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        eos_token_id_tensor = (
+            torch.tensor(eos_token_id).to(input_ids.device)
+            if eos_token_id is not None
+            else None
+        )
+        output_scores = (
+            output_scores
+            if output_scores is not None
+            else self.generation_config.output_scores
+        )
         output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+            output_attentions
+            if output_attentions is not None
+            else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
             return_dict_in_generate
@@ -1341,19 +1473,33 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        cross_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
 
         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
             encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
             )
 
         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        unfinished_sequences = torch.ones(
+            input_ids.shape[0], dtype=torch.long, device=input_ids.device
+        )
         this_peer_finished = False  # used by synced_gpus only
 
         # auto-regressive generation
@@ -1394,7 +1540,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     scores += (next_token_scores,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)
@@ -1413,8 +1561,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -1427,7 +1579,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
+                    .ne(eos_token_id_tensor.unsqueeze(1))
+                    .prod(dim=0)
                 )
 
             # stop when each sentence is finished
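A toy, self-contained illustration of the bookkeeping in the two hunks above: finished rows receive the pad token, and a row stays unfinished only while its sampled token matches none of the EOS ids (all ids below are made up):

import torch

pad_token_id = 0
eos_token_id_tensor = torch.tensor([2, 5])          # two EOS ids, made up for the demo
next_tokens = torch.tensor([7, 2, 5])               # sampled tokens for a batch of 3
unfinished_sequences = torch.tensor([1, 1, 0])      # the third row finished earlier

# finished rows get the pad token instead of their sampled token
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
    1 - unfinished_sequences
)
print(next_tokens)  # tensor([7, 2, 0])

# a row stays unfinished only if its token differs from every EOS id
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
    .ne(eos_token_id_tensor.unsqueeze(1))
    .prod(dim=0)
)
print(unfinished_sequences)  # tensor([1, 0, 0])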
@@ -1446,7 +1600,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             return input_ids
 
     # def backward(
     #     self,
     #     tokenizer,
@@ -1539,20 +1692,20 @@ def _rotate_half(x):
 def apply_rotary_pos_emb(t, freqs):
-    """ Apply rotary embedding to the first rotary_dim of the iput
+    """Apply rotary embedding to the first rotary_dim of the iput
 
     Arguments:
       t (tensor(batch_size, seq_len, n_head, head_dim)):
         the input embedding/hidden states
      freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
        the cached cos/sin position embeddings
     """
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
     if apply_rotary_emb_func is not None and t.is_cuda:
         # apply_rotary_emb in flash_attn requires cos/sin to be of
         # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
         cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]