Format qwen model.
parent 255a2ff71c
commit 611396b656
@ -38,7 +38,9 @@ from torch import nn
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
-SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+SUPPORT_TORCH2 = (
+    hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
+)


 from configuration_qwen import QWenConfig
@ -70,6 +72,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
 apply_rotary_emb_func = None
 rms_norm = None

+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
@ -85,17 +88,19 @@ def quantize_cache_v(fdata, bits, qmax, qmin):
     qmin = qmin.to(device)
     scale = (fmax - fmin) / (qmax - qmin)
     zero = qmin - fmin / scale
-    scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
-    zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
     # Quantize
     res_data = fdata / scale + zero
     qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
     return qdata.contiguous(), scale, zero

+
 def dequantize_cache_torch(qdata, scale, zero):
     data = scale * (qdata - zero)
     return data

+
 class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
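Aside (not part of the diff): a minimal, self-contained sketch of the affine uint8 round trip that quantize_cache_v / dequantize_cache_torch implement for the KV cache. It reuses the scale / zero-point math shown above, but the choice of reducing the statistics over the sequence dimension is an assumption for illustration only; the model's helper derives its own per-channel statistics.

import torch

def quantize(fdata, qmin=0.0, qmax=255.0):
    # assumption: min/max taken over the sequence dim, for illustration only
    fmax = fdata.amax(dim=-2, keepdim=True)
    fmin = fdata.amin(dim=-2, keepdim=True)
    scale = (fmax - fmin) / (qmax - qmin)
    zero = qmin - fmin / scale
    qdata = torch.clamp(fdata / scale + zero, qmin, qmax).to(torch.uint8)
    return qdata, scale, zero

def dequantize(qdata, scale, zero):
    # mirrors dequantize_cache_torch: data = scale * (qdata - zero)
    return scale * (qdata - zero)

x = torch.randn(1, 2, 5, 4)      # (batch, head, seq, head_dim), as after permute(0, 2, 1, 3)
q, s, z = quantize(x)
x_hat = dequantize(q, s, z)
print((x - x_hat).abs().max())   # small reconstruction error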
@ -138,12 +143,20 @@ class QWenAttention(nn.Module):
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)

         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
-        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
-        self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+        self.softmax_in_fp32 = (
+            config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
+        )
+        self.use_cache_quantization = (
+            config.use_cache_quantization
+            if hasattr(config, "use_cache_quantization")
+            else False
+        )
+        self.use_cache_kernel = (
+            config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
+        )
         cache_dtype = torch.float
         if self.bf16:
-            cache_dtype=torch.bfloat16
+            cache_dtype = torch.bfloat16
         elif config.fp16:
             cache_dtype = torch.float16
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
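Side note on the hasattr(...) ternaries above: getattr with a default expresses the same fallback more compactly. A tiny illustration (not part of the diff; the checkpoint keeps the explicit ternaries):

from types import SimpleNamespace

config = SimpleNamespace(softmax_in_fp32=True)  # hypothetical config object
softmax_in_fp32 = getattr(config, "softmax_in_fp32", False)                # True
use_cache_quantization = getattr(config, "use_cache_quantization", False)  # False, attribute missing
print(softmax_in_fp32, use_cache_quantization)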
@ -152,19 +165,25 @@ class QWenAttention(nn.Module):
         if config.use_cache_quantization and config.use_cache_kernel:
             # pre check if the support files existing
             module_root = pathlib.Path(__file__).parent
-            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
+            src_files = (
+                "cache_autogptq_cuda_256.cpp",
+                "cache_autogptq_cuda_kernel_256.cu",
+            )
+            if any(not (module_root / src).is_file() for src in src_files):
                 warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
                 self.cache_kernels = None
             else:
                 try:
                     from .cpp_kernels import cache_autogptq_cuda_256
+
                     self.cache_kernels = cache_autogptq_cuda_256
                 except ImportError:
                     warnings.warn("Failed to import KV cache kernels.")
                     self.cache_kernels = None

-    def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+    def _attn(
+        self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
+    ):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
@ -172,11 +191,18 @@ class QWenAttention(nn.Module):
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+                    query.contiguous()
+                    if query.dtype == torch.float16
+                    else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,
-                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+                    qk_scale.contiguous()
+                    if qk_scale.dtype == torch.float16
+                    else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous()
+                    if qk_zero.dtype == torch.float16
+                    else qk_zero.to(torch.float16).contiguous(),
+                )
                 # attn_weights = attn_weights.to(query.dtype).contiguous()
             else:
                 key = dequantize_cache_torch(qk, qk_scale, qk_zero)
@ -189,7 +215,7 @@ class QWenAttention(nn.Module):
             size_temp = value[0].size(-1)
         else:
             size_temp = value.size(-1)
-        attn_weights = attn_weights / (size_temp ** 0.5)
+        attn_weights = attn_weights / (size_temp**0.5)

         mask_value = torch.finfo(attn_weights.dtype).min
         if causal_mask is not None:
@ -217,11 +243,18 @@ class QWenAttention(nn.Module):
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+                    attn_weights.contiguous()
+                    if attn_weights.dtype == torch.float16
+                    else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(),  # dtype: int32
                     attn_output,
-                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+                    qv_scale.contiguous()
+                    if qv_scale.dtype == torch.float16
+                    else qv_scale.to(torch.float16).contiguous(),
+                    qv_zero.contiguous()
+                    if qv_zero.dtype == torch.float16
+                    else qv_zero.to(torch.float16).contiguous(),
+                )
                 if attn_output.dtype != query.dtype:
                     attn_output = attn_output.to(query.dtype)
                     attn_weights = attn_weights.to(query.dtype)
@ -283,21 +316,26 @@ class QWenAttention(nn.Module):
                 rotary_pos_emb = (rotary_pos_emb,) * 2
                 q_pos_emb, k_pos_emb = rotary_pos_emb
                 # Slice the pos emb for current inference
-                query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                query_list += [
+                    apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
+                ]
+                key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
             query = torch.cat(query_list, dim=0)
             key = torch.cat(key_list, dim=0)

         if self.use_cache_quantization:
-            key = quantize_cache_v(key.permute(0, 2, 1, 3),
+            key = quantize_cache_v(
+                key.permute(0, 2, 1, 3),
                 bits=8,
                 qmin=self.cache_qmin,
-                qmax=self.cache_qmax)
-            value = quantize_cache_v(value.permute(0, 2, 1, 3),
+                qmax=self.cache_qmax,
+            )
+            value = quantize_cache_v(
+                value.permute(0, 2, 1, 3),
                 bits=8,
                 qmin=self.cache_qmin,
-                qmax=self.cache_qmax)
-
+                qmax=self.cache_qmax,
+            )

         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
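For context (not part of the diff): the apply_rotary_pos_emb calls above apply rotary position embedding to the query/key slices. A minimal standalone sketch of the idea; the shapes and the base of 10000 are assumptions for illustration, not values taken from this file.

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(t, cos, sin):
    # t: (batch, seq, n_head, head_dim); cos/sin broadcast over batch and heads
    return t * cos[None, :, None, :] + rotate_half(t) * sin[None, :, None, :]

seq_len, head_dim = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((angles, angles), dim=-1)        # (seq, head_dim)

q = torch.randn(1, seq_len, 2, head_dim)
q_rot = apply_rope(q, emb.cos(), emb.sin())
print(q_rot.shape)                               # torch.Size([1, 8, 2, 16])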
@ -305,12 +343,16 @@ class QWenAttention(nn.Module):
                 # use_cache_quantization:
                 # present=((q_key,key_scale,key_zero_point),
                 #          (q_value,value_scale,value_zero_point))
-                key = (torch.cat((past_key[0], key[0]), dim=2),
+                key = (
+                    torch.cat((past_key[0], key[0]), dim=2),
                     torch.cat((past_key[1], key[1]), dim=2),
-                    torch.cat((past_key[2], key[2]), dim=2))
-                value = (torch.cat((past_value[0], value[0]), dim=2),
+                    torch.cat((past_key[2], key[2]), dim=2),
+                )
+                value = (
+                    torch.cat((past_value[0], value[0]), dim=2),
                     torch.cat((past_value[1], value[1]), dim=2),
-                    torch.cat((past_value[2], value[2]), dim=2))
+                    torch.cat((past_value[2], value[2]), dim=2),
+                )
             else:
                 # not use_cache_quantization:
                 # present=(key,value)
@ -347,11 +389,11 @@ class QWenAttention(nn.Module):

         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
                 if causal_mask is not None:
-                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+                    attention_mask = attention_mask.masked_fill(
+                        ~causal_mask, torch.finfo(query.dtype).min
+                    )
             else:
                 attention_mask = causal_mask
             attn_output = F.scaled_dot_product_attention(
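Aside (not part of the diff): when torch >= 2 is available the branch above hands the mask to F.scaled_dot_product_attention; otherwise the manual _attn path computes softmax(QK^T / sqrt(d)) V itself. A quick equivalence check, for illustration only:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

scores = (q @ k.transpose(-1, -2)) / (q.size(-1) ** 0.5)
causal = torch.ones(5, 5, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
manual = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(ref, manual, atol=1e-5))   # expected: True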
@ -362,16 +404,16 @@ class QWenAttention(nn.Module):
             attn_output, attn_weight = self._attn(
                 query, key, value, causal_mask, attention_mask, head_mask
             )
-        context_layer = self._merge_heads(
-            attn_output, self.num_heads, self.head_dim
-        )
+        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

         attn_output = self.c_proj(context_layer)

         outputs = (attn_output, present)
         if output_attentions:
             if not self.use_cache_quantization and SUPPORT_TORCH2:
-                raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
+                raise ValueError(
+                    "Cannot output attentions while using scaled_dot_product_attention"
+                )
             else:
                 outputs += (attn_weight,)

@ -507,7 +549,11 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
+        self.use_cache_quantization = (
+            self.config.use_cache_quantization
+            if hasattr(self.config, "use_cache_quantization")
+            else False
+        )

         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
@ -521,25 +567,14 @@ class QWenModel(QWenPreTrainedModel):
             self.rotary_ndims = None
         else:
             assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                config.kv_channels * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else config.kv_channels
-        )
+            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
+        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
         self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

         self.is_fp32 = not (config.bf16 or config.fp16)

         self.h = nn.ModuleList(
-            [
-                QWenBlock(
-                    config
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+            [QWenBlock(config) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(
             self.embed_dim,
@ -659,7 +694,12 @@ class QWenModel(QWenPreTrainedModel):
         else:
             ntk_alpha_list = []
             if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
+                true_seq_lens = (
+                    attention_mask.squeeze(1)
+                    .squeeze(1)
+                    .eq(0)
+                    .sum(dim=-1, dtype=torch.int32)
+                )
                 for i in range(hidden_states.size()[0]):
                     true_seq_len = true_seq_lens[i].item()
                     ntk_alpha = self.get_ntk_alpha(true_seq_len)
@ -669,7 +709,8 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+            for ntk_alpha in ntk_alpha_list
         ]

         hidden_states = self.drop(hidden_states)
@ -686,7 +727,6 @@ class QWenModel(QWenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

@ -727,7 +767,9 @@ class QWenModel(QWenPreTrainedModel):
                 presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (
+                    outputs[2 if use_cache else 1],
+                )

         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@ -756,7 +798,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         super().__init__(config)
         assert (
             config.bf16 + config.fp16 + config.fp32 <= 1
-        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
+        ), 'Only one of "bf16", "fp16", "fp32" can be true'

         autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

@ -764,27 +806,35 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if SUPPORT_BF16:
                 logger.warn(
                     "The model is automatically converting to bf16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.bf16 = True
             elif SUPPORT_FP16:
                 logger.warn(
                     "The model is automatically converting to fp16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.fp16 = True
             else:
                 config.fp32 = True

         if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
-            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
+            logger.warn(
+                'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
+            )
         if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
-            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
+            logger.warn(
+                "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
+            )
         if config.fp32:
             if SUPPORT_BF16:
-                logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
             elif SUPPORT_FP16:
-                logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
+                )

         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
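Aside (not part of the diff): the warnings above refer to the precision switches a caller can pass through from_pretrained. Loading code along these lines pins the dtype explicitly instead of relying on the auto-detection; the model name is illustrative.

from transformers import AutoModelForCausalLM

# Force bf16 (or pass fp16=True / fp32=True) instead of the automatic choice.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True
)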
@ -845,7 +895,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
@ -887,7 +936,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # )
         # loss.backward()

-
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@ -904,7 +952,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def _reorder_cache(
         past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
-
         return tuple(
             tuple(
                 past_state.index_select(0, beam_idx.to(past_state.device))
@ -924,10 +971,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )

         assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         else:
@ -937,7 +988,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if stop_words_ids is None:
             stop_words_ids = []

-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@ -949,9 +1000,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )

-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
@ -968,7 +1019,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             context_length=len(context_tokens),
             chat_format=generation_config.chat_format,
             verbose=False,
-            errors='replace'
+            errors="replace",
         )

         # as history is a copy of the user inputs,
@ -990,14 +1041,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Generator[str, Any, None]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
             stop_words_ids = []

-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@ -1009,9 +1064,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )

-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
@ -1023,10 +1078,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             logits_processor.append(stop_words_logits_processor)
         input_ids = torch.tensor([context_tokens]).to(self.device)

-        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+        from transformers_stream_generator.main import (
+            NewGenerationMixin,
+            StreamGenerationConfig,
+        )
+
         self.__class__.generate_stream = NewGenerationMixin.generate
         self.__class__.sample_stream = NewGenerationMixin.sample_stream
-        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+        stream_config = StreamGenerationConfig(
+            **generation_config.to_dict(), do_stream=True
+        )

         def stream_generator():
             outputs = []
@ -1036,9 +1097,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 generation_config=stream_config,
                 logits_processor=logits_processor,
                 seed=-1,
-                **kwargs):
+                **kwargs,
+            ):
                 outputs.append(token.item())
-                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
+                yield tokenizer.decode(
+                    outputs, skip_special_tokens=True, errors="ignore"
+                )

         return stream_generator()

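Continuing the chat() sketch above (same model and tokenizer), the streaming variant for reference only: chat_stream() yields the cumulative decoded reply, so printing the delta gives an incremental console stream.

last = ""
for partial in model.chat_stream(tokenizer, "Write a haiku about autumn.", history=None):
    print(partial[len(last):], end="", flush=True)
    last = partial
print()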
@ -1056,7 +1120,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )

         # Process stop_words_ids.
         stop_words_ids = kwargs.pop("stop_words_ids", None)
@ -1093,7 +1161,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[
+            Callable[[int, torch.Tensor], List[int]]
+        ] = None,
         synced_gpus: Optional[bool] = None,
         assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
@ -1101,7 +1171,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-
         if synced_gpus is None:
             synced_gpus = False

@ -1114,8 +1183,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # two conditions must be met
             # 1) the generation config must have been created from the model config (`_from_model_config` field);
             # 2) the generation config must have seen no modification since its creation (the hash is the same).
-            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
-                self.generation_config
+            if (
+                self.generation_config._from_model_config
+                and self.generation_config._original_object_hash
+                == hash(self.generation_config)
             ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
@ -1129,15 +1200,26 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config = self.generation_config

         generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+        model_kwargs = generation_config.update(
+            **kwargs
+        )  # All unused kwargs must be model kwargs
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())

         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )

-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if (
+            generation_config.pad_token_id is None
+            and generation_config.eos_token_id is not None
+        ):
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
@ -1146,7 +1228,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+            )
             generation_config.pad_token_id = eos_token_id

         # 3. Define model inputs
@ -1169,12 +1253,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         else:
             model_kwargs["use_cache"] = generation_config.use_cache

-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        accepts_attention_mask = "attention_mask" in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+        if (
+            model_kwargs.get("attention_mask", None) is None
+            and requires_attention_mask
+            and accepts_attention_mask
+        ):
+            model_kwargs[
+                "attention_mask"
+            ] = self._prepare_attention_mask_for_generation(
+                inputs_tensor,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
             )

         # decoder-only models should use left-padding for generation
@ -1184,7 +1278,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if (
             generation_config.pad_token_id is not None
             and len(inputs_tensor.shape) == 2
-            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+            > 0
         ):
             logger.warning(
                 "A decoder-only architecture is being used, but right-padding was detected! For correct "
@ -1209,14 +1304,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 device=inputs_tensor.device,
             )
         else:
-            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+            input_ids = (
+                inputs_tensor
+                if model_input_name == "input_ids"
+                else model_kwargs.pop("input_ids")
+            )

         if streamer is not None:
             streamer.put(input_ids.cpu())

         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_max_length = (
+            kwargs.get("max_length") is None
+            and generation_config.max_length is not None
+        )
         if generation_config.max_new_tokens is not None:
             if not has_default_max_length and generation_config.max_length is not None:
                 logger.warning(
@ -1225,8 +1327,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+            generation_config.max_length = (
+                generation_config.max_new_tokens + input_ids_length
+            )
+        self._validate_generated_length(
+            generation_config, input_ids_length, has_default_max_length
+        )

         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
@ -1291,7 +1397,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 **model_kwargs,
             )

-
     def sample_base(
         self,
         input_ids: torch.LongTensor,
@ -1309,10 +1414,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ):
-
         # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
         # if max_length is not None:
         #     warnings.warn(
         #         "`max_length` is deprecated in this function, use"
@ -1320,18 +1430,40 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         #         UserWarning,
         #     )
         #     stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        logits_warper = (
+            logits_warper if logits_warper is not None else LogitsProcessorList()
+        )
+        pad_token_id = (
+            pad_token_id
+            if pad_token_id is not None
+            else self.generation_config.pad_token_id
+        )
+        eos_token_id = (
+            eos_token_id
+            if eos_token_id is not None
+            else self.generation_config.eos_token_id
+        )
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        eos_token_id_tensor = (
+            torch.tensor(eos_token_id).to(input_ids.device)
+            if eos_token_id is not None
+            else None
+        )
+        output_scores = (
+            output_scores
+            if output_scores is not None
+            else self.generation_config.output_scores
+        )
         output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+            output_attentions
+            if output_attentions is not None
+            else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
             return_dict_in_generate
@ -1341,19 +1473,33 @@ class QWenLMHeadModel(QWenPreTrainedModel):

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        cross_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )

         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
             encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
             )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        unfinished_sequences = torch.ones(
+            input_ids.shape[0], dtype=torch.long, device=input_ids.device
+        )

         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
@ -1394,7 +1540,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     scores += (next_token_scores,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)
@ -1413,8 +1561,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )

             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@ -1427,7 +1579,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
+                    .ne(eos_token_id_tensor.unsqueeze(1))
+                    .prod(dim=0)
                 )

             # stop when each sentence is finished
@ -1446,7 +1600,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):

         return input_ids

-
     # def backward(
     #     self,
     #     tokenizer,
@ -1539,7 +1692,7 @@ def _rotate_half(x):


 def apply_rotary_pos_emb(t, freqs):
-    """ Apply rotary embedding to the first rotary_dim of the iput
+    """Apply rotary embedding to the first rotary_dim of the iput

     Arguments:
       t (tensor(batch_size, seq_len, n_head, head_dim)):