Format qwen model.

Colin 2024-01-07 16:23:04 +08:00
parent 255a2ff71c
commit 611396b656
1 changed file with 309 additions and 156 deletions

@@ -38,7 +38,9 @@ from torch import nn
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
SUPPORT_TORCH2 = (
hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
)
from configuration_qwen import QWenConfig
@@ -70,6 +72,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
apply_rotary_emb_func = None
rms_norm = None
def quantize_cache_v(fdata, bits, qmax, qmin):
# b, s, head, h-dim->b, head, s, h-dim
qtype = torch.uint8
@@ -85,17 +88,19 @@ def quantize_cache_v(fdata, bits, qmax, qmin):
qmin = qmin.to(device)
scale = (fmax - fmin) / (qmax - qmin)
zero = qmin - fmin / scale
scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
# Quantize
res_data = fdata / scale + zero
qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
return qdata.contiguous(), scale, zero
def dequantize_cache_torch(qdata, scale, zero):
data = scale * (qdata - zero)
return data
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
@@ -138,12 +143,20 @@ class QWenAttention(nn.Module):
self.register_buffer("logn_tensor", logn_tensor, persistent=False)
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
self.softmax_in_fp32 = (
config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
)
self.use_cache_quantization = (
config.use_cache_quantization
if hasattr(config, "use_cache_quantization")
else False
)
self.use_cache_kernel = (
config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
)
cache_dtype = torch.float
if self.bf16:
cache_dtype=torch.bfloat16
cache_dtype = torch.bfloat16
elif config.fp16:
cache_dtype = torch.float16
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
@@ -152,19 +165,25 @@ class QWenAttention(nn.Module):
if config.use_cache_quantization and config.use_cache_kernel:
# pre check if the support files existing
module_root = pathlib.Path(__file__).parent
src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
if any(not (module_root/src).is_file() for src in src_files):
src_files = (
"cache_autogptq_cuda_256.cpp",
"cache_autogptq_cuda_kernel_256.cu",
)
if any(not (module_root / src).is_file() for src in src_files):
warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
self.cache_kernels = None
else:
try:
from .cpp_kernels import cache_autogptq_cuda_256
self.cache_kernels = cache_autogptq_cuda_256
except ImportError:
warnings.warn("Failed to import KV cache kernels.")
self.cache_kernels = None
def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
def _attn(
self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
):
device = query.device
if self.use_cache_quantization:
qk, qk_scale, qk_zero = key
@@ -172,11 +191,18 @@ class QWenAttention(nn.Module):
shape = query.shape[:-1] + (qk.shape[-2],)
attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
self.cache_kernels.vecquant8matmul_batched_faster_old(
query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
query.contiguous()
if query.dtype == torch.float16
else query.to(torch.float16).contiguous(),
qk.transpose(-1, -2).contiguous(),
attn_weights,
qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
qk_scale.contiguous()
if qk_scale.dtype == torch.float16
else qk_scale.to(torch.float16).contiguous(),
qk_zero.contiguous()
if qk_zero.dtype == torch.float16
else qk_zero.to(torch.float16).contiguous(),
)
# attn_weights = attn_weights.to(query.dtype).contiguous()
else:
key = dequantize_cache_torch(qk, qk_scale, qk_zero)
@@ -189,7 +215,7 @@ class QWenAttention(nn.Module):
size_temp = value[0].size(-1)
else:
size_temp = value.size(-1)
attn_weights = attn_weights / (size_temp ** 0.5)
attn_weights = attn_weights / (size_temp**0.5)
mask_value = torch.finfo(attn_weights.dtype).min
if causal_mask is not None:
@@ -217,11 +243,18 @@ class QWenAttention(nn.Module):
shape = attn_weights.shape[:-1] + (query.shape[-1],)
attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
attn_weights.contiguous()
if attn_weights.dtype == torch.float16
else attn_weights.to(torch.float16).contiguous(),
qv.contiguous(), # dtype: int32
attn_output,
qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
qv_scale.contiguous()
if qv_scale.dtype == torch.float16
else qv_scale.to(torch.float16).contiguous(),
qv_zero.contiguous()
if qv_zero.dtype == torch.float16
else qv_zero.to(torch.float16).contiguous(),
)
if attn_output.dtype != query.dtype:
attn_output = attn_output.to(query.dtype)
attn_weights = attn_weights.to(query.dtype)
@@ -283,21 +316,26 @@ class QWenAttention(nn.Module):
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
query_list += [
apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
]
key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
if self.use_cache_quantization:
key = quantize_cache_v(key.permute(0, 2, 1, 3),
bits=8,
qmin=self.cache_qmin,
qmax=self.cache_qmax)
value = quantize_cache_v(value.permute(0, 2, 1, 3),
bits=8,
qmin=self.cache_qmin,
qmax=self.cache_qmax)
key = quantize_cache_v(
key.permute(0, 2, 1, 3),
bits=8,
qmin=self.cache_qmin,
qmax=self.cache_qmax,
)
value = quantize_cache_v(
value.permute(0, 2, 1, 3),
bits=8,
qmin=self.cache_qmin,
qmax=self.cache_qmax,
)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
@@ -305,12 +343,16 @@ class QWenAttention(nn.Module):
# use_cache_quantization:
# present=((q_key,key_scale,key_zero_point),
# (q_value,value_scale,value_zero_point))
key = (torch.cat((past_key[0], key[0]), dim=2),
torch.cat((past_key[1], key[1]), dim=2),
torch.cat((past_key[2], key[2]), dim=2))
value = (torch.cat((past_value[0], value[0]), dim=2),
torch.cat((past_value[1], value[1]), dim=2),
torch.cat((past_value[2], value[2]), dim=2))
key = (
torch.cat((past_key[0], key[0]), dim=2),
torch.cat((past_key[1], key[1]), dim=2),
torch.cat((past_key[2], key[2]), dim=2),
)
value = (
torch.cat((past_value[0], value[0]), dim=2),
torch.cat((past_value[1], value[1]), dim=2),
torch.cat((past_value[2], value[2]), dim=2),
)
else:
# not use_cache_quantization:
# present=(key,value)
@@ -347,11 +389,11 @@ class QWenAttention(nn.Module):
if not self.use_cache_quantization and SUPPORT_TORCH2:
if attention_mask is not None:
attention_mask = attention_mask.expand(
-1, -1, causal_mask.size(2), -1
)
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
if causal_mask is not None:
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
attention_mask = attention_mask.masked_fill(
~causal_mask, torch.finfo(query.dtype).min
)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(
@@ -362,16 +404,16 @@ class QWenAttention(nn.Module):
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
if not self.use_cache_quantization and SUPPORT_TORCH2:
raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
raise ValueError(
"Cannot output attentions while using scaled_dot_product_attention"
)
else:
outputs += (attn_weight,)
@@ -507,7 +549,11 @@ class QWenModel(QWenPreTrainedModel):
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
self.use_cache_quantization = (
self.config.use_cache_quantization
if hasattr(self.config, "use_cache_quantization")
else False
)
self.gradient_checkpointing = False
self.use_dynamic_ntk = config.use_dynamic_ntk
@@ -521,25 +567,14 @@ class QWenModel(QWenPreTrainedModel):
self.rotary_ndims = None
else:
assert config.rotary_pct < 1
self.rotary_ndims = int(
config.kv_channels * config.rotary_pct
)
dim = (
self.rotary_ndims
if self.rotary_ndims is not None
else config.kv_channels
)
self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.is_fp32 = not (config.bf16 or config.fp16)
self.h = nn.ModuleList(
[
QWenBlock(
config
)
for i in range(config.num_hidden_layers)
]
[QWenBlock(config) for i in range(config.num_hidden_layers)]
)
self.ln_f = RMSNorm(
self.embed_dim,
@@ -659,7 +694,12 @@ class QWenModel(QWenPreTrainedModel):
else:
ntk_alpha_list = []
if attention_mask is not None and kv_seq_len > self.seq_length:
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
true_seq_lens = (
attention_mask.squeeze(1)
.squeeze(1)
.eq(0)
.sum(dim=-1, dtype=torch.int32)
)
for i in range(hidden_states.size()[0]):
true_seq_len = true_seq_lens[i].item()
ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -669,7 +709,8 @@ class QWenModel(QWenPreTrainedModel):
ntk_alpha_list.append(ntk_alpha)
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
rotary_pos_emb_list = [
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
for ntk_alpha in ntk_alpha_list
]
hidden_states = self.drop(hidden_states)
@@ -686,7 +727,6 @@ class QWenModel(QWenPreTrainedModel):
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -727,7 +767,9 @@ class QWenModel(QWenPreTrainedModel):
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
@@ -756,7 +798,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
), 'Only one of "bf16", "fp16", "fp32" can be true'
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
@@ -764,27 +806,35 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
logger.warn(
'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
)
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
logger.warn(
"Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
)
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
logger.warn(
'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
)
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
logger.warn(
'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
)
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -845,7 +895,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
@@ -887,7 +936,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# )
# loss.backward()
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -904,7 +952,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
@@ -924,10 +971,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config
generation_config = (
generation_config
if generation_config is not None
else self.generation_config
)
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
else:
@@ -937,7 +988,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
max_window_size = kwargs.get("max_window_size", None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
@@ -949,18 +1000,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
stop_words_ids.extend(
get_stop_words_ids(generation_config.chat_format, tokenizer)
)
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
return_dict_in_generate=False,
generation_config=generation_config,
**kwargs,
)
input_ids,
stop_words_ids=stop_words_ids,
return_dict_in_generate=False,
generation_config=generation_config,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
@@ -968,7 +1019,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
context_length=len(context_tokens),
chat_format=generation_config.chat_format,
verbose=False,
errors='replace'
errors="replace",
)
# as history is a copy of the user inputs,
@@ -980,24 +1031,28 @@ class QWenLMHeadModel(QWenPreTrainedModel):
return response, history
def chat_stream(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Generator[str, Any, None]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
generation_config = (
generation_config
if generation_config is not None
else self.generation_config
)
assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
max_window_size = kwargs.get("max_window_size", None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
@@ -1009,9 +1064,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
stop_words_ids.extend(
get_stop_words_ids(generation_config.chat_format, tokenizer)
)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
@@ -1023,22 +1078,31 @@ class QWenLMHeadModel(QWenPreTrainedModel):
logits_processor.append(stop_words_logits_processor)
input_ids = torch.tensor([context_tokens]).to(self.device)
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from transformers_stream_generator.main import (
NewGenerationMixin,
StreamGenerationConfig,
)
self.__class__.generate_stream = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
stream_config = StreamGenerationConfig(
**generation_config.to_dict(), do_stream=True
)
def stream_generator():
outputs = []
for token in self.generate_stream(
input_ids,
return_dict_in_generate=False,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs):
input_ids,
return_dict_in_generate=False,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs,
):
outputs.append(token.item())
yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
yield tokenizer.decode(
outputs, skip_special_tokens=True, errors="ignore"
)
return stream_generator()
@@ -1056,7 +1120,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = generation_config if generation_config is not None else self.generation_config
generation_config = (
generation_config
if generation_config is not None
else self.generation_config
)
# Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
@@ -1093,7 +1161,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
@@ -1101,9 +1171,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
if synced_gpus is None:
synced_gpus = False
synced_gpus = False
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
@@ -1114,8 +1183,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
self.generation_config
if (
self.generation_config._from_model_config
and self.generation_config._original_object_hash
== hash(self.generation_config)
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
@@ -1129,15 +1200,26 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
model_kwargs = generation_config.update(
**kwargs
) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if (
generation_config.pad_token_id is None
and generation_config.eos_token_id is not None
):
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -1146,7 +1228,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
)
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
@@ -1169,12 +1253,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
else:
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
accepts_attention_mask = "attention_mask" in set(
inspect.signature(self.forward).parameters.keys()
)
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
if (
model_kwargs.get("attention_mask", None) is None
and requires_attention_mask
and accepts_attention_mask
):
model_kwargs[
"attention_mask"
] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
)
# decoder-only models should use left-padding for generation
@@ -1184,7 +1278,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if (
generation_config.pad_token_id is not None
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
> 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -1209,14 +1304,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
device=inputs_tensor.device,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
if streamer is not None:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
@@ -1225,8 +1327,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_length
)
self._validate_generated_length(
generation_config, input_ids_length, has_default_max_length
)
# 7. determine generation mode
generation_mode = self._get_generation_mode(generation_config, assistant_model)
@@ -1264,7 +1370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config=generation_config, stopping_criteria=stopping_criteria
)
# 10. go into different generation modes
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
@@ -1291,7 +1397,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
**model_kwargs,
)
def sample_base(
self,
input_ids: torch.LongTensor,
@@ -1309,10 +1414,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
):
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
# if max_length is not None:
# warnings.warn(
# "`max_length` is deprecated in this function, use"
@@ -1320,18 +1430,40 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# UserWarning,
# )
# stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
logits_warper = (
logits_warper if logits_warper is not None else LogitsProcessorList()
)
pad_token_id = (
pad_token_id
if pad_token_id is not None
else self.generation_config.pad_token_id
)
eos_token_id = (
eos_token_id
if eos_token_id is not None
else self.generation_config.eos_token_id
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(input_ids.device)
if eos_token_id is not None
else None
)
output_scores = (
output_scores
if output_scores is not None
else self.generation_config.output_scores
)
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
output_attentions
if output_attentions is not None
else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate
@@ -1341,19 +1473,33 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_attentions = (
model_kwargs["encoder_outputs"].get("attentions")
if output_attentions
else None
)
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
model_kwargs["encoder_outputs"].get("hidden_states")
if output_hidden_states
else None
)
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
@@ -1394,7 +1540,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
(outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
@@ -1413,8 +1561,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -1427,7 +1579,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
next_tokens.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
@@ -1446,7 +1600,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
return input_ids
# def backward(
# self,
# tokenizer,
@@ -1539,20 +1692,20 @@ def _rotate_half(x):
def apply_rotary_pos_emb(t, freqs):
""" Apply rotary embedding to the first rotary_dim of the iput
"""Apply rotary embedding to the first rotary_dim of the iput
Arguments:
t (tensor(batch_size, seq_len, n_head, head_dim)):
the input embedding/hidden states
freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
the cached cos/sin position embeddings
the cached cos/sin position embeddings
"""
rot_dim = freqs[0].shape[-1]
cos, sin = freqs
t_float = t.float()
if apply_rotary_emb_func is not None and t.is_cuda:
# apply_rotary_emb in flash_attn requires cos/sin to be of
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
# apply_rotary_emb in flash_attn requires cos/sin to be of
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
# to the first rotary_dim of the input
cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]