Refine the Qwen model.
commit 063f722177
parent 7d7b4381f8
@@ -465,12 +465,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     ) -> Union[GenerateOutput, torch.LongTensor]:
         generation_config = self.generation_config
 
-        # Process stop_words_ids.
-        stop_words_logits_processor = StopWordsLogitsProcessor(
-            stop_words_ids=stop_words_ids,
-            eos_token_id=generation_config.eos_token_id,
-        )
-        logits_processor = LogitsProcessorList([stop_words_logits_processor])
 
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
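This hunk drops the early construction of the stop-words processor; the next hunk rebuilds it after the generation config has been finalized. For orientation, `StopWordsLogitsProcessor` comes from Qwen's own `qwen_generation_utils.py`; below is a minimal sketch of the same idea against the public `transformers` interface (the class name and the simplified matching logic are illustrative, not Qwen's exact implementation, and `eos_token_id` is assumed to be a single int):

    import torch
    from transformers import LogitsProcessor

    class SimpleStopWordsProcessor(LogitsProcessor):
        """Illustrative: once a stop-word token sequence has been emitted,
        force the next token to be EOS."""

        def __init__(self, stop_words_ids, eos_token_id):
            self.stop_words_ids = stop_words_ids  # list of token-id lists
            self.eos_token_id = eos_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            for i in range(input_ids.shape[0]):
                for stop_seq in self.stop_words_ids:
                    if input_ids[i, -len(stop_seq):].tolist() == stop_seq:
                        scores[i, :] = -float("inf")        # ban every continuation ...
                        scores[i, self.eos_token_id] = 0.0  # ... except ending the sequence
            return scores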
@@ -523,7 +517,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config.max_length = generation_config.max_new_tokens + input_ids_length
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
-
         # 8. prepare distribution pre_processing samplers
+
+        stop_words_logits_processor = StopWordsLogitsProcessor(
+            stop_words_ids=stop_words_ids,
+            eos_token_id=generation_config.eos_token_id,
+        )
+        logits_processor = LogitsProcessorList([stop_words_logits_processor])
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
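One thing to watch in this hunk: `logits_processor` is assigned twice. In `transformers`, `_get_logits_processor` accepts a `logits_processor` argument and merges caller-supplied processors with the config-derived ones, so presumably the elided middle of this call passes the freshly built list through; otherwise the second assignment would silently discard the stop-words processor. The same merging is available through the public API, sketched below (the checkpoint name and stop string are placeholders, and `SimpleStopWordsProcessor` is the illustrative class from the sketch above):

    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    stop_words_ids = [tokenizer.encode("\nObservation:")]  # placeholder stop sequence
    custom = LogitsProcessorList([SimpleStopWordsProcessor(stop_words_ids, tokenizer.eos_token_id)])

    inputs = tokenizer("The weather today is", return_tensors="pt")
    # generate() merges `logits_processor` with the processors derived from the config.
    out = model.generate(**inputs, do_sample=True, max_new_tokens=32, logits_processor=custom)
    print(tokenizer.decode(out[0], skip_special_tokens=True))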
@@ -535,16 +534,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             negative_prompt_attention_mask=None,
         )
 
-        # 9. prepare stopping criteria
-
-        stopping_criteria = StoppingCriteriaList()
-        stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria
-        )
-        # 10. go into different generation modes
-
-        # 11. prepare logits warper
-        logits_warper = self._get_logits_warper(generation_config)
 
         # 12. expand input_ids with `num_return_sequences` additional sequences per batch
         input_ids, model_kwargs = self._expand_inputs_for_generation(
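The stopping-criteria and logits-warper setup removed here reappears after the sampling entry point in the next hunk. For reference, `StoppingCriteriaList` is the standard extension point for custom halting conditions; a small illustrative criterion follows (token id 198 is a placeholder for a newline in a GPT-2-style vocabulary, and the single-bool return matches the transformers release this code appears to target):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class MaxNewlinesCriteria(StoppingCriteria):
        """Illustrative: stop once the first sequence in the batch has
        emitted `limit` newline tokens."""

        def __init__(self, newline_token_id: int, limit: int = 3):
            self.newline_token_id = newline_token_id
            self.limit = limit

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return bool((input_ids[0] == self.newline_token_id).sum() >= self.limit)

    criteria = StoppingCriteriaList([MaxNewlinesCriteria(newline_token_id=198)])
    # usage: model.generate(..., stopping_criteria=criteria)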
@@ -555,42 +544,24 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         )
 
         # 13. run sample
-        return self.sample_base(
-            input_ids,
-            logits_processor=logits_processor,
-            logits_warper=logits_warper,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=generation_config.pad_token_id,
-            eos_token_id=generation_config.eos_token_id,
-            output_scores=generation_config.output_scores,
-            streamer=streamer,
-            **model_kwargs,
 
+        pad_token_id=generation_config.pad_token_id
+        eos_token_id=generation_config.eos_token_id
+        streamer=streamer
+
+        # init values
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
+        )
 
-    def sample_base(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_scores: Optional[bool] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ):
         # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_warper = self._get_logits_warper(generation_config)
 
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
 
         # init attention / hidden states / scores tuples
         scores = None
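This hunk removes the `sample_base` wrapper and inlines its setup directly into `generate`. The surviving `# init values` block follows the usual transformers pattern: take the explicit argument if given, fall back to `generation_config`, normalize `eos_token_id` to a list, and keep a tensor copy for vectorized comparisons. Condensed into a hypothetical standalone helper (the name and signature are ours, not the model's):

    import torch
    from transformers import LogitsProcessorList, StoppingCriteriaList

    def resolve_sampling_defaults(generation_config, input_ids, logits_processor=None,
                                  stopping_criteria=None, pad_token_id=None, eos_token_id=None):
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]  # normalize a single id to a list
        # Tensor copy on the input device so "did we hit EOS" checks stay vectorized.
        eos_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        return logits_processor, stopping_criteria, pad_token_id, eos_token_id, eos_tensor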
@@ -607,10 +578,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # forward pass to get next token
             outputs = self(**model_inputs)
 
-            next_token_logits = outputs.logits[:, -1, :]
+            next_token_scores = outputs.logits[:, -1, :]
 
             # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_processor(input_ids, next_token_scores)
             next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # sample
@@ -619,8 +590,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
             # update generated ids, model inputs, and length for next step