diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index 959b650..d15aefc 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -38,7 +38,9 @@ from torch import nn SUPPORT_CUDA = torch.cuda.is_available() SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 -SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 +SUPPORT_TORCH2 = ( + hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2 +) from configuration_qwen import QWenConfig @@ -70,6 +72,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo apply_rotary_emb_func = None rms_norm = None + def quantize_cache_v(fdata, bits, qmax, qmin): # b, s, head, h-dim->b, head, s, h-dim qtype = torch.uint8 @@ -85,17 +88,19 @@ def quantize_cache_v(fdata, bits, qmax, qmin): qmin = qmin.to(device) scale = (fmax - fmin) / (qmax - qmin) zero = qmin - fmin / scale - scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() - zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() + zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() # Quantize res_data = fdata / scale + zero qdata = torch.clamp(res_data, qmin, qmax).to(qtype) return qdata.contiguous(), scale, zero + def dequantize_cache_torch(qdata, scale, zero): data = scale * (qdata - zero) return data + class QWenAttention(nn.Module): def __init__(self, config): super().__init__() @@ -138,12 +143,20 @@ class QWenAttention(nn.Module): self.register_buffer("logn_tensor", logn_tensor, persistent=False) self.attn_dropout = nn.Dropout(config.attn_dropout_prob) - self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False - self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False - self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False + self.softmax_in_fp32 = ( + config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False + ) + self.use_cache_quantization = ( + config.use_cache_quantization + if hasattr(config, "use_cache_quantization") + else False + ) + self.use_cache_kernel = ( + config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False + ) cache_dtype = torch.float if self.bf16: - cache_dtype=torch.bfloat16 + cache_dtype = torch.bfloat16 elif config.fp16: cache_dtype = torch.float16 self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) @@ -152,19 +165,25 @@ class QWenAttention(nn.Module): if config.use_cache_quantization and config.use_cache_kernel: # pre check if the support files existing module_root = pathlib.Path(__file__).parent - src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") - if any(not (module_root/src).is_file() for src in src_files): + src_files = ( + "cache_autogptq_cuda_256.cpp", + "cache_autogptq_cuda_kernel_256.cu", + ) + if any(not (module_root / src).is_file() for src in src_files): warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") self.cache_kernels = None else: try: from .cpp_kernels import cache_autogptq_cuda_256 + self.cache_kernels = cache_autogptq_cuda_256 except ImportError: warnings.warn("Failed to import KV cache kernels.") self.cache_kernels = None - def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): + def _attn( + self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None + ): device = query.device if self.use_cache_quantization: qk, qk_scale, qk_zero = key @@ -172,11 +191,18 @@ class QWenAttention(nn.Module): shape = query.shape[:-1] + (qk.shape[-2],) attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) self.cache_kernels.vecquant8matmul_batched_faster_old( - query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), + query.contiguous() + if query.dtype == torch.float16 + else query.to(torch.float16).contiguous(), qk.transpose(-1, -2).contiguous(), attn_weights, - qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), - qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) + qk_scale.contiguous() + if qk_scale.dtype == torch.float16 + else qk_scale.to(torch.float16).contiguous(), + qk_zero.contiguous() + if qk_zero.dtype == torch.float16 + else qk_zero.to(torch.float16).contiguous(), + ) # attn_weights = attn_weights.to(query.dtype).contiguous() else: key = dequantize_cache_torch(qk, qk_scale, qk_zero) @@ -189,7 +215,7 @@ class QWenAttention(nn.Module): size_temp = value[0].size(-1) else: size_temp = value.size(-1) - attn_weights = attn_weights / (size_temp ** 0.5) + attn_weights = attn_weights / (size_temp**0.5) mask_value = torch.finfo(attn_weights.dtype).min if causal_mask is not None: @@ -217,11 +243,18 @@ class QWenAttention(nn.Module): shape = attn_weights.shape[:-1] + (query.shape[-1],) attn_output = torch.zeros(shape, dtype=torch.float16, device=device) self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( - attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), + attn_weights.contiguous() + if attn_weights.dtype == torch.float16 + else attn_weights.to(torch.float16).contiguous(), qv.contiguous(), # dtype: int32 attn_output, - qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), - qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) + qv_scale.contiguous() + if qv_scale.dtype == torch.float16 + else qv_scale.to(torch.float16).contiguous(), + qv_zero.contiguous() + if qv_zero.dtype == torch.float16 + else qv_zero.to(torch.float16).contiguous(), + ) if attn_output.dtype != query.dtype: attn_output = attn_output.to(query.dtype) attn_weights = attn_weights.to(query.dtype) @@ -283,21 +316,26 @@ class QWenAttention(nn.Module): rotary_pos_emb = (rotary_pos_emb,) * 2 q_pos_emb, k_pos_emb = rotary_pos_emb # Slice the pos emb for current inference - query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] - key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] + query_list += [ + apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb) + ] + key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)] query = torch.cat(query_list, dim=0) key = torch.cat(key_list, dim=0) if self.use_cache_quantization: - key = quantize_cache_v(key.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax) - value = quantize_cache_v(value.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax) - + key = quantize_cache_v( + key.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax, + ) + value = quantize_cache_v( + value.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax, + ) if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] @@ -305,12 +343,16 @@ class QWenAttention(nn.Module): # use_cache_quantization: # present=((q_key,key_scale,key_zero_point), # (q_value,value_scale,value_zero_point)) - key = (torch.cat((past_key[0], key[0]), dim=2), - torch.cat((past_key[1], key[1]), dim=2), - torch.cat((past_key[2], key[2]), dim=2)) - value = (torch.cat((past_value[0], value[0]), dim=2), - torch.cat((past_value[1], value[1]), dim=2), - torch.cat((past_value[2], value[2]), dim=2)) + key = ( + torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2), + ) + value = ( + torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2), + ) else: # not use_cache_quantization: # present=(key,value) @@ -347,11 +389,11 @@ class QWenAttention(nn.Module): if not self.use_cache_quantization and SUPPORT_TORCH2: if attention_mask is not None: - attention_mask = attention_mask.expand( - -1, -1, causal_mask.size(2), -1 - ) + attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1) if causal_mask is not None: - attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) + attention_mask = attention_mask.masked_fill( + ~causal_mask, torch.finfo(query.dtype).min + ) else: attention_mask = causal_mask attn_output = F.scaled_dot_product_attention( @@ -362,16 +404,16 @@ class QWenAttention(nn.Module): attn_output, attn_weight = self._attn( query, key, value, causal_mask, attention_mask, head_mask ) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) + context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(context_layer) outputs = (attn_output, present) if output_attentions: if not self.use_cache_quantization and SUPPORT_TORCH2: - raise ValueError("Cannot output attentions while using scaled_dot_product_attention") + raise ValueError( + "Cannot output attentions while using scaled_dot_product_attention" + ) else: outputs += (attn_weight,) @@ -507,7 +549,11 @@ class QWenModel(QWenPreTrainedModel): self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.embed_dim = config.hidden_size - self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False + self.use_cache_quantization = ( + self.config.use_cache_quantization + if hasattr(self.config, "use_cache_quantization") + else False + ) self.gradient_checkpointing = False self.use_dynamic_ntk = config.use_dynamic_ntk @@ -521,25 +567,14 @@ class QWenModel(QWenPreTrainedModel): self.rotary_ndims = None else: assert config.rotary_pct < 1 - self.rotary_ndims = int( - config.kv_channels * config.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else config.kv_channels - ) + self.rotary_ndims = int(config.kv_channels * config.rotary_pct) + dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) self.is_fp32 = not (config.bf16 or config.fp16) self.h = nn.ModuleList( - [ - QWenBlock( - config - ) - for i in range(config.num_hidden_layers) - ] + [QWenBlock(config) for i in range(config.num_hidden_layers)] ) self.ln_f = RMSNorm( self.embed_dim, @@ -659,7 +694,12 @@ class QWenModel(QWenPreTrainedModel): else: ntk_alpha_list = [] if attention_mask is not None and kv_seq_len > self.seq_length: - true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) + true_seq_lens = ( + attention_mask.squeeze(1) + .squeeze(1) + .eq(0) + .sum(dim=-1, dtype=torch.int32) + ) for i in range(hidden_states.size()[0]): true_seq_len = true_seq_lens[i].item() ntk_alpha = self.get_ntk_alpha(true_seq_len) @@ -669,7 +709,8 @@ class QWenModel(QWenPreTrainedModel): ntk_alpha_list.append(ntk_alpha) self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list rotary_pos_emb_list = [ - self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) + for ntk_alpha in ntk_alpha_list ] hidden_states = self.drop(hidden_states) @@ -686,7 +727,6 @@ class QWenModel(QWenPreTrainedModel): all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -727,7 +767,9 @@ class QWenModel(QWenPreTrainedModel): presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + ( + outputs[2 if use_cache else 1], + ) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) @@ -756,7 +798,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): super().__init__(config) assert ( config.bf16 + config.fp16 + config.fp32 <= 1 - ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" + ), 'Only one of "bf16", "fp16", "fp32" can be true' autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 @@ -764,27 +806,35 @@ class QWenLMHeadModel(QWenPreTrainedModel): if SUPPORT_BF16: logger.warn( "The model is automatically converting to bf16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." + 'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".' ) config.bf16 = True elif SUPPORT_FP16: logger.warn( "The model is automatically converting to fp16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." + 'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".' ) config.fp16 = True else: config.fp32 = True if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: - logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") + logger.warn( + 'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".' + ) if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: - logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") + logger.warn( + "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster" + ) if config.fp32: if SUPPORT_BF16: - logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") + logger.warn( + 'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".' + ) elif SUPPORT_FP16: - logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") + logger.warn( + 'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".' + ) self.transformer = QWenModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -845,7 +895,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) @@ -887,7 +936,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): # ) # loss.backward() - if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -904,7 +952,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): def _reorder_cache( past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: - return tuple( tuple( past_state.index_select(0, beam_idx.to(past_state.device)) @@ -924,10 +971,14 @@ class QWenLMHeadModel(QWenPreTrainedModel): generation_config: Optional[GenerationConfig] = None, **kwargs, ) -> Tuple[str, HistoryType]: - generation_config = generation_config if generation_config is not None else self.generation_config + generation_config = ( + generation_config + if generation_config is not None + else self.generation_config + ) assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT if history is None: history = [] else: @@ -937,7 +988,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): if stop_words_ids is None: stop_words_ids = [] - max_window_size = kwargs.get('max_window_size', None) + max_window_size = kwargs.get("max_window_size", None) if max_window_size is None: max_window_size = generation_config.max_window_size raw_text, context_tokens = make_context( @@ -949,18 +1000,18 @@ class QWenLMHeadModel(QWenPreTrainedModel): chat_format=generation_config.chat_format, ) - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) + stop_words_ids.extend( + get_stop_words_ids(generation_config.chat_format, tokenizer) + ) input_ids = torch.tensor([context_tokens]).to(self.device) outputs = self.generate( - input_ids, - stop_words_ids=stop_words_ids, - return_dict_in_generate=False, - generation_config=generation_config, - **kwargs, - ) - + input_ids, + stop_words_ids=stop_words_ids, + return_dict_in_generate=False, + generation_config=generation_config, + **kwargs, + ) + response = decode_tokens( outputs[0], tokenizer, @@ -968,7 +1019,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): context_length=len(context_tokens), chat_format=generation_config.chat_format, verbose=False, - errors='replace' + errors="replace", ) # as history is a copy of the user inputs, @@ -980,24 +1031,28 @@ class QWenLMHeadModel(QWenPreTrainedModel): return response, history def chat_stream( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stop_words_ids: Optional[List[List[int]]] = None, - logits_processor: Optional[LogitsProcessorList] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stop_words_ids: Optional[List[List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, ) -> Generator[str, Any, None]: - generation_config = generation_config if generation_config is not None else self.generation_config - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + generation_config = ( + generation_config + if generation_config is not None + else self.generation_config + ) + assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT if history is None: history = [] if stop_words_ids is None: stop_words_ids = [] - max_window_size = kwargs.get('max_window_size', None) + max_window_size = kwargs.get("max_window_size", None) if max_window_size is None: max_window_size = generation_config.max_window_size raw_text, context_tokens = make_context( @@ -1009,9 +1064,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): chat_format=generation_config.chat_format, ) - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) + stop_words_ids.extend( + get_stop_words_ids(generation_config.chat_format, tokenizer) + ) if stop_words_ids is not None: stop_words_logits_processor = StopWordsLogitsProcessor( stop_words_ids=stop_words_ids, @@ -1023,22 +1078,31 @@ class QWenLMHeadModel(QWenPreTrainedModel): logits_processor.append(stop_words_logits_processor) input_ids = torch.tensor([context_tokens]).to(self.device) - from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig + from transformers_stream_generator.main import ( + NewGenerationMixin, + StreamGenerationConfig, + ) + self.__class__.generate_stream = NewGenerationMixin.generate self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + stream_config = StreamGenerationConfig( + **generation_config.to_dict(), do_stream=True + ) def stream_generator(): outputs = [] for token in self.generate_stream( - input_ids, - return_dict_in_generate=False, - generation_config=stream_config, - logits_processor=logits_processor, - seed=-1, - **kwargs): + input_ids, + return_dict_in_generate=False, + generation_config=stream_config, + logits_processor=logits_processor, + seed=-1, + **kwargs, + ): outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') + yield tokenizer.decode( + outputs, skip_special_tokens=True, errors="ignore" + ) return stream_generator() @@ -1056,7 +1120,11 @@ class QWenLMHeadModel(QWenPreTrainedModel): streamer: Optional["BaseStreamer"] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: - generation_config = generation_config if generation_config is not None else self.generation_config + generation_config = ( + generation_config + if generation_config is not None + else self.generation_config + ) # Process stop_words_ids. stop_words_ids = kwargs.pop("stop_words_ids", None) @@ -1093,7 +1161,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, @@ -1101,9 +1171,8 @@ class QWenLMHeadModel(QWenPreTrainedModel): negative_prompt_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: - if synced_gpus is None: - synced_gpus = False + synced_gpus = False # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() @@ -1114,8 +1183,10 @@ class QWenLMHeadModel(QWenPreTrainedModel): # two conditions must be met # 1) the generation config must have been created from the model config (`_from_model_config` field); # 2) the generation config must have seen no modification since its creation (the hash is the same). - if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( - self.generation_config + if ( + self.generation_config._from_model_config + and self.generation_config._original_object_hash + == hash(self.generation_config) ): new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: @@ -1129,15 +1200,26 @@ class QWenLMHeadModel(QWenPreTrainedModel): generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " @@ -1146,7 +1228,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) generation_config.pad_token_id = eos_token_id # 3. Define model inputs @@ -1169,12 +1253,22 @@ class QWenLMHeadModel(QWenPreTrainedModel): else: model_kwargs["use_cache"] = generation_config.use_cache - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, ) # decoder-only models should use left-padding for generation @@ -1184,7 +1278,8 @@ class QWenLMHeadModel(QWenPreTrainedModel): if ( generation_config.pad_token_id is not None and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " @@ -1209,14 +1304,21 @@ class QWenLMHeadModel(QWenPreTrainedModel): device=inputs_tensor.device, ) else: - input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + input_ids = ( + inputs_tensor + if model_input_name == "input_ids" + else model_kwargs.pop("input_ids") + ) if streamer is not None: streamer.put(input_ids.cpu()) # 6. Prepare `max_length` depending on other stopping criteria. input_ids_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) if generation_config.max_new_tokens is not None: if not has_default_max_length and generation_config.max_length is not None: logger.warning( @@ -1225,8 +1327,12 @@ class QWenLMHeadModel(QWenPreTrainedModel): "Please refer to the documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_length - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_length + ) + self._validate_generated_length( + generation_config, input_ids_length, has_default_max_length + ) # 7. determine generation mode generation_mode = self._get_generation_mode(generation_config, assistant_model) @@ -1264,7 +1370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes - + # 11. prepare logits warper logits_warper = self._get_logits_warper(generation_config) @@ -1291,7 +1397,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): **model_kwargs, ) - def sample_base( self, input_ids: torch.LongTensor, @@ -1309,10 +1414,15 @@ class QWenLMHeadModel(QWenPreTrainedModel): streamer: Optional["BaseStreamer"] = None, **model_kwargs, ): - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) # if max_length is not None: # warnings.warn( # "`max_length` is deprecated in this function, use" @@ -1320,18 +1430,40 @@ class QWenLMHeadModel(QWenPreTrainedModel): # UserWarning, # ) # stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) + if eos_token_id is not None + else None + ) + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -1341,19 +1473,33 @@ class QWenLMHeadModel(QWenPreTrainedModel): # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + unfinished_sequences = torch.ones( + input_ids.shape[0], dtype=torch.long, device=input_ids.device + ) this_peer_finished = False # used by synced_gpus only # auto-regressive generation @@ -1394,7 +1540,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): scores += (next_token_scores,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -1413,8 +1561,12 @@ class QWenLMHeadModel(QWenPreTrainedModel): # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -1427,7 +1579,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + next_tokens.tile(eos_token_id_tensor.shape[0], 1) + .ne(eos_token_id_tensor.unsqueeze(1)) + .prod(dim=0) ) # stop when each sentence is finished @@ -1446,7 +1600,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): return input_ids - # def backward( # self, # tokenizer, @@ -1539,20 +1692,20 @@ def _rotate_half(x): def apply_rotary_pos_emb(t, freqs): - """ Apply rotary embedding to the first rotary_dim of the iput + """Apply rotary embedding to the first rotary_dim of the iput Arguments: t (tensor(batch_size, seq_len, n_head, head_dim)): the input embedding/hidden states freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): - the cached cos/sin position embeddings + the cached cos/sin position embeddings """ rot_dim = freqs[0].shape[-1] cos, sin = freqs t_float = t.float() if apply_rotary_emb_func is not None and t.is_cuda: - # apply_rotary_emb in flash_attn requires cos/sin to be of - # shape (seqlen, rotary_dim / 2) and apply rotary embedding + # apply_rotary_emb in flash_attn requires cos/sin to be of + # shape (seqlen, rotary_dim / 2) and apply rotary embedding # to the first rotary_dim of the input cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]