Update qwen model.

This commit is contained in:
Colin 2024-01-10 13:16:54 +00:00
parent 245d251663
commit 7d7b4381f8
2 changed files with 14 additions and 61 deletions

View File

@ -34,7 +34,6 @@ from qwen_generation_utils import (
HistoryType, HistoryType,
make_context, make_context,
decode_tokens, decode_tokens,
get_stop_words_ids,
StopWordsLogitsProcessor, StopWordsLogitsProcessor,
) )
@ -416,18 +415,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
query_assistant: str, query_assistant: str,
history: Optional[HistoryType], history: Optional[HistoryType],
system: str = "You are a helpful assistant.", system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs, **kwargs,
) -> Tuple[str, HistoryType]: ) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config generation_config = self.generation_config
if history is None: if history is None:
history = [] history = []
else: else:
history = copy.deepcopy(history) history = copy.deepcopy(history)
if stop_words_ids is None:
stop_words_ids = [] stop_words_ids = []
max_window_size = kwargs.get("max_window_size", None) max_window_size = kwargs.get("max_window_size", None)
@ -442,12 +438,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
max_window_size=max_window_size max_window_size=max_window_size
) )
stop_words_ids.extend(get_stop_words_ids(tokenizer)) stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
input_ids = torch.tensor([context_tokens]).to(self.device) input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate( outputs = self.generate(
input_ids, input_ids,
stop_words_ids=stop_words_ids, stop_words_ids=stop_words_ids,
generation_config=generation_config,
**kwargs, **kwargs,
) )
decoded, response, end_reason = decode_tokens( decoded, response, end_reason = decode_tokens(
@ -463,57 +458,20 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def generate( def generate(
self, self,
inputs: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None, stop_words_ids = [],
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None, streamer: Optional["BaseStreamer"] = None,
**kwargs, **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]: ) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = generation_config if generation_config is not None else self.generation_config generation_config = self.generation_config
# Process stop_words_ids. # Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
if stop_words_ids is None and generation_config is not None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor( stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids, stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id, eos_token_id=generation_config.eos_token_id,
) )
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor]) logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
return self.generate_base(
inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
assistant_model=assistant_model,
streamer=streamer,
**kwargs,
)
def generate_base(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class() self._validate_model_class()
@ -523,8 +481,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self._validate_model_kwargs(model_kwargs.copy()) self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined # 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None: if model_kwargs.get("attention_mask", None) is None:
@ -574,11 +531,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor, logits_processor=logits_processor,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids, negative_prompt_ids=None,
negative_prompt_attention_mask=negative_prompt_attention_mask, negative_prompt_attention_mask=None,
) )
# 9. prepare stopping criteria # 9. prepare stopping criteria
stopping_criteria = StoppingCriteriaList()
stopping_criteria = self._get_stopping_criteria( stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria generation_config=generation_config, stopping_criteria=stopping_criteria
) )

View File

@ -105,12 +105,6 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
) )
return tokens, attention_mask, position_ids return tokens, attention_mask, position_ids
def get_stop_words_ids(tokenizer):
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
return stop_words_ids
def make_context( def make_context(
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
query: str, query: str,