From 063f7221772a8554b65ab29085cd302be12d331e Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 11 Jan 2024 07:00:18 +0000 Subject: [PATCH] Refine model of qwen. --- qwen/modeling_qwen.py | 67 ++++++++++++------------------------------- 1 file changed, 18 insertions(+), 49 deletions(-) diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index 5f42bf6..ae8a293 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -465,12 +465,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): ) -> Union[GenerateOutput, torch.LongTensor]: generation_config = self.generation_config - # Process stop_words_ids. - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - logits_processor = LogitsProcessorList([stop_words_logits_processor]) # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() @@ -523,7 +517,12 @@ class QWenLMHeadModel(QWenPreTrainedModel): generation_config.max_length = generation_config.max_new_tokens + input_ids_length self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - # 8. prepare distribution pre_processing samplers + + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + logits_processor = LogitsProcessorList([stop_words_logits_processor]) logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, @@ -535,16 +534,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): negative_prompt_attention_mask=None, ) - # 9. prepare stopping criteria - - stopping_criteria = StoppingCriteriaList() - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - # 10. go into different generation modes - - # 11. prepare logits warper - logits_warper = self._get_logits_warper(generation_config) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( @@ -555,42 +544,24 @@ class QWenLMHeadModel(QWenPreTrainedModel): ) # 13. run sample - return self.sample_base( - input_ids, - logits_processor=logits_processor, - logits_warper=logits_warper, - stopping_criteria=stopping_criteria, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - output_scores=generation_config.output_scores, - streamer=streamer, - **model_kwargs, + + pad_token_id=generation_config.pad_token_id + eos_token_id=generation_config.eos_token_id + streamer=streamer + + + # init values + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=StoppingCriteriaList() ) - def sample_base( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_scores: Optional[bool] = None, - streamer: Optional["BaseStreamer"] = None, - **model_kwargs, - ): - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_warper = self._get_logits_warper(generation_config) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores # init attention / hidden states / scores tuples scores = None @@ -607,10 +578,10 @@ class QWenLMHeadModel(QWenPreTrainedModel): # forward pass to get next token outputs = self(**model_inputs) - next_token_logits = outputs.logits[:, -1, :] + next_token_scores = outputs.logits[:, -1, :] # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_processor(input_ids, next_token_scores) next_token_scores = logits_warper(input_ids, next_token_scores) # sample @@ -619,8 +590,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): # finished sentences should have their next token be a padding token if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step