Update qwen model.
This commit is contained in:
parent
245d251663
commit
7d7b4381f8
|
@ -34,7 +34,6 @@ from qwen_generation_utils import (
|
||||||
HistoryType,
|
HistoryType,
|
||||||
make_context,
|
make_context,
|
||||||
decode_tokens,
|
decode_tokens,
|
||||||
get_stop_words_ids,
|
|
||||||
StopWordsLogitsProcessor,
|
StopWordsLogitsProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -416,18 +415,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
query_assistant: str,
|
query_assistant: str,
|
||||||
history: Optional[HistoryType],
|
history: Optional[HistoryType],
|
||||||
system: str = "You are a helpful assistant.",
|
system: str = "You are a helpful assistant.",
|
||||||
stop_words_ids: Optional[List[List[int]]] = None,
|
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[str, HistoryType]:
|
) -> Tuple[str, HistoryType]:
|
||||||
generation_config = generation_config if generation_config is not None else self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|
||||||
if history is None:
|
if history is None:
|
||||||
history = []
|
history = []
|
||||||
else:
|
else:
|
||||||
history = copy.deepcopy(history)
|
history = copy.deepcopy(history)
|
||||||
|
|
||||||
if stop_words_ids is None:
|
|
||||||
stop_words_ids = []
|
stop_words_ids = []
|
||||||
|
|
||||||
max_window_size = kwargs.get("max_window_size", None)
|
max_window_size = kwargs.get("max_window_size", None)
|
||||||
|
@ -442,12 +438,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
max_window_size=max_window_size
|
max_window_size=max_window_size
|
||||||
)
|
)
|
||||||
|
|
||||||
stop_words_ids.extend(get_stop_words_ids(tokenizer))
|
stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
|
||||||
input_ids = torch.tensor([context_tokens]).to(self.device)
|
input_ids = torch.tensor([context_tokens]).to(self.device)
|
||||||
outputs = self.generate(
|
outputs = self.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
stop_words_ids=stop_words_ids,
|
stop_words_ids=stop_words_ids,
|
||||||
generation_config=generation_config,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
decoded, response, end_reason = decode_tokens(
|
decoded, response, end_reason = decode_tokens(
|
||||||
|
@ -463,57 +458,20 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
stop_words_ids = [],
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
generation_config = generation_config if generation_config is not None else self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|
||||||
# Process stop_words_ids.
|
# Process stop_words_ids.
|
||||||
stop_words_ids = kwargs.pop("stop_words_ids", None)
|
|
||||||
if stop_words_ids is None and generation_config is not None:
|
|
||||||
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
|
|
||||||
if stop_words_ids is None:
|
|
||||||
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
|
|
||||||
|
|
||||||
if stop_words_ids is not None:
|
|
||||||
stop_words_logits_processor = StopWordsLogitsProcessor(
|
stop_words_logits_processor = StopWordsLogitsProcessor(
|
||||||
stop_words_ids=stop_words_ids,
|
stop_words_ids=stop_words_ids,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
eos_token_id=generation_config.eos_token_id,
|
||||||
)
|
)
|
||||||
if logits_processor is None:
|
|
||||||
logits_processor = LogitsProcessorList([stop_words_logits_processor])
|
logits_processor = LogitsProcessorList([stop_words_logits_processor])
|
||||||
else:
|
|
||||||
logits_processor.append(stop_words_logits_processor)
|
|
||||||
|
|
||||||
return self.generate_base(
|
|
||||||
inputs,
|
|
||||||
generation_config=generation_config,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
stopping_criteria=stopping_criteria,
|
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
|
||||||
assistant_model=assistant_model,
|
|
||||||
streamer=streamer,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_base(
|
|
||||||
self,
|
|
||||||
inputs: Optional[torch.Tensor] = None,
|
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
|
||||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
|
||||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
|
||||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||||
self._validate_model_class()
|
self._validate_model_class()
|
||||||
|
|
||||||
|
@ -523,8 +481,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
self._validate_model_kwargs(model_kwargs.copy())
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
# 2. Set generation parameters if not already defined
|
# 2. Set generation parameters if not already defined
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
|
||||||
|
|
||||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||||
if model_kwargs.get("attention_mask", None) is None:
|
if model_kwargs.get("attention_mask", None) is None:
|
||||||
|
@ -574,11 +531,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
negative_prompt_ids=negative_prompt_ids,
|
negative_prompt_ids=None,
|
||||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
negative_prompt_attention_mask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 9. prepare stopping criteria
|
# 9. prepare stopping criteria
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteriaList()
|
||||||
stopping_criteria = self._get_stopping_criteria(
|
stopping_criteria = self._get_stopping_criteria(
|
||||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||||
)
|
)
|
||||||
|
|
|
@ -105,12 +105,6 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
|
||||||
)
|
)
|
||||||
return tokens, attention_mask, position_ids
|
return tokens, attention_mask, position_ids
|
||||||
|
|
||||||
|
|
||||||
def get_stop_words_ids(tokenizer):
|
|
||||||
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
|
|
||||||
return stop_words_ids
|
|
||||||
|
|
||||||
|
|
||||||
def make_context(
|
def make_context(
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
Loading…
Reference in New Issue