Update qwen model.
This commit is contained in:

parent 245d251663
commit 7d7b4381f8
modeling_qwen.py:

@@ -34,7 +34,6 @@ from qwen_generation_utils import (
     HistoryType,
     make_context,
     decode_tokens,
-    get_stop_words_ids,
     StopWordsLogitsProcessor,
 )
 
@@ -416,18 +415,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         query_assistant: str,
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
-        stop_words_ids: Optional[List[List[int]]] = None,
-        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = self.generation_config
 
         if history is None:
             history = []
         else:
             history = copy.deepcopy(history)
 
-        if stop_words_ids is None:
-            stop_words_ids = []
+        stop_words_ids = []
 
         max_window_size = kwargs.get("max_window_size", None)

@@ -442,12 +438,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             max_window_size=max_window_size
         )
 
-        stop_words_ids.extend(get_stop_words_ids(tokenizer))
+        stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
-            generation_config=generation_config,
             **kwargs,
         )
         decoded, response, end_reason = decode_tokens(
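Note: the practical effect of the chat() hunks above is that per-call stop_words_ids and generation_config overrides are gone from the signature; decoding is now driven entirely by the model's own generation_config. A minimal caller-side sketch (the checkpoint name is a placeholder, assuming the usual trust_remote_code loading flow for a repo that ships this modeling file):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder checkpoint; any Qwen checkpoint bundling this modeling
    # file behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat", trust_remote_code=True
    ).eval()

    # After this commit, tune decoding on the model object up front;
    # chat() no longer takes generation_config/stop_words_ids keywords.
    model.generation_config.top_p = 0.8
    model.generation_config.max_new_tokens = 512

Also note that this fork's chat() takes an extra query_assistant string, so the exact call shape differs from upstream Qwen's model.chat(tokenizer, query, history=None); the call itself is omitted above rather than guessed at.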
@@ -463,57 +458,20 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
-        generation_config: Optional[GenerationConfig] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        stop_words_ids = [],
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = self.generation_config
 
         # Process stop_words_ids.
-        stop_words_ids = kwargs.pop("stop_words_ids", None)
-        if stop_words_ids is None and generation_config is not None:
-            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
-        if stop_words_ids is None:
-            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
-
-        if stop_words_ids is not None:
-            stop_words_logits_processor = StopWordsLogitsProcessor(
-                stop_words_ids=stop_words_ids,
-                eos_token_id=generation_config.eos_token_id,
-            )
-            if logits_processor is None:
-                logits_processor = LogitsProcessorList([stop_words_logits_processor])
-            else:
-                logits_processor.append(stop_words_logits_processor)
-
-        return self.generate_base(
-            inputs,
-            generation_config=generation_config,
-            logits_processor=logits_processor,
-            stopping_criteria=stopping_criteria,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            assistant_model=assistant_model,
-            streamer=streamer,
-            **kwargs,
-        )
-
-    def generate_base(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        generation_config: Optional[GenerationConfig] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        negative_prompt_ids: Optional[torch.Tensor] = None,
-        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[GenerateOutput, torch.LongTensor]:
+        stop_words_logits_processor = StopWordsLogitsProcessor(
+            stop_words_ids=stop_words_ids,
+            eos_token_id=generation_config.eos_token_id,
+        )
+        logits_processor = LogitsProcessorList([stop_words_logits_processor])
+
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
 
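Note: the separate generate_base() indirection is folded away here, and generate() now builds a StopWordsLogitsProcessor unconditionally from its new stop_words_ids parameter. For readers unfamiliar with the mechanism, a toy processor in the same spirit is sketched below; this is not the repo's implementation (the real StopWordsLogitsProcessor in qwen_generation_utils.py validates inputs and tracks batched matches), just the core idea of masking logits once a stop sequence completes:

    import torch
    from transformers import LogitsProcessor

    class NaiveStopWords(LogitsProcessor):
        # Toy sketch only, not the repo's class.
        def __init__(self, stop_words_ids, eos_token_id):
            self.stop_words_ids = stop_words_ids  # list of token-id lists
            self.eos_token_id = eos_token_id      # assumed a single int here

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            for i in range(input_ids.shape[0]):
                seq = input_ids[i].tolist()
                for stop in self.stop_words_ids:
                    if len(seq) >= len(stop) and seq[-len(stop):] == stop:
                        # A stop sequence just completed: force EOS next step.
                        scores[i, :] = float("-inf")
                        scores[i, self.eos_token_id] = 0.0
            return scores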
@@ -523,8 +481,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
 
         if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
@@ -574,11 +531,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
             model_kwargs=model_kwargs,
-            negative_prompt_ids=negative_prompt_ids,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            negative_prompt_ids=None,
+            negative_prompt_attention_mask=None,
         )
 
         # 9. prepare stopping criteria
+
+        stopping_criteria = StoppingCriteriaList()
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
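Note: together with the removal of the stopping_criteria parameter from generate(), the added StoppingCriteriaList() line means only criteria derived from generation_config (such as max length) apply; generate() no longer exposes a way for callers to pass their own. For context, a custom criterion in transformers looks like the illustrative sketch below (the class name and numbers are made up):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopAfterNNewTokens(StoppingCriteria):
        # Illustrative only: halt once n tokens exist beyond the prompt.
        def __init__(self, prompt_len: int, n: int):
            self.prompt_len = prompt_len
            self.n = n

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return input_ids.shape[-1] - self.prompt_len >= self.n

    # Upstream HF generate() accepts, e.g.:
    #     model.generate(..., stopping_criteria=StoppingCriteriaList([StopAfterNNewTokens(10, 64)]))
    # After this commit the Qwen override builds its own fresh list instead.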
qwen_generation_utils.py:

@@ -105,12 +105,6 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
     )
     return tokens, attention_mask, position_ids
 
-
-def get_stop_words_ids(tokenizer):
-    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
-    return stop_words_ids
-
-
 def make_context(
     tokenizer: PreTrainedTokenizer,
     query: str,
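Note: the deleted get_stop_words_ids() helper and the new inline list in chat() are behaviorally identical. With tokenizer loaded as in the earlier sketch:

    # Removed helper (was in qwen_generation_utils.py):
    #     def get_stop_words_ids(tokenizer):
    #         stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    #         return stop_words_ids
    # Inlined replacement at the call site in chat():
    stop_words_ids = []
    stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
    # Each inner list is one stop sequence; for the stock Qwen tokenizer
    # these are the ChatML delimiters <|im_end|> and <|im_start|>.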