Format the qwen model.

This commit is contained in:
Colin 2024-01-13 17:16:43 +08:00
parent 5cf6e8b013
commit d13f7e6c57
1 changed file with 6 additions and 15 deletions

@@ -38,6 +38,7 @@ from qwen_generation_utils import (
 )
 import sys
 sys.path.append("..")
 from tools import show
@@ -144,7 +145,6 @@ class QWenAttention(nn.Module):
         else:
             attention_mask = causal_mask
         # qk = query @ key.transpose(-2, -1)
         # qk = qk[0]
         # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
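
The lines commented out above dumped each layer's raw query-key score matrix to an image for inspection, via the tools.show helper imported earlier. That helper is outside this diff; below is a minimal sketch of what a DumpTensorToImage-style function might look like, assuming a 2-D tensor and matplotlib (the color map and normalization are guesses, not the repo's actual implementation):

import matplotlib
matplotlib.use("Agg")  # render off-screen, no display required
import matplotlib.pyplot as plt
import torch

def dump_tensor_to_image(tensor: torch.Tensor, path: str) -> None:
    # Detach from the autograd graph and move to CPU before plotting.
    array = tensor.detach().float().cpu().numpy()
    plt.imshow(array, cmap="viridis")
    plt.colorbar()
    plt.savefig(path)
    plt.close()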
@@ -441,12 +441,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
-            tokenizer,
-            query,
-            query_assistant,
-            history=history,
-            system=system,
-            max_window_size=max_window_size
+            tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
         )
         stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
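
For context, make_context assembles a ChatML-style prompt from the query, the optional assistant prefix, the history, and the system message, and the <|im_start|>/<|im_end|> ids are then registered as stop words so generation halts at a turn boundary. A rough illustration of the layout it produces (illustrative only; the authoritative template lives in qwen_generation_utils, not in this diff):

# Illustrative only; the exact template is defined by make_context in
# qwen_generation_utils.
raw_text = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "What is the capital of France?<|im_end|>\n"
    "<|im_start|>assistant\n"
)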
@@ -469,13 +464,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
-        stop_words_ids = [],
+        stop_words_ids=[],
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         generation_config = self.generation_config
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
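
A side note on the signature touched here: stop_words_ids=[] is a mutable default, which Python evaluates once at definition time, so if the function ever mutated it in place the list would be shared across calls. The diff only normalizes spacing; below is a minimal standalone sketch of the pitfall and the usual None idiom (hypothetical function names, not code from this file):

def risky(items=[]):
    # The same list object is reused on every call.
    items.append("x")
    return items

def safe(items=None):
    # A fresh list is created per call.
    if items is None:
        items = []
    items.append("x")
    return items

risky()
print(risky())  # ['x', 'x'] -- state leaked between calls
safe()
print(safe())   # ['x']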
@@ -486,7 +480,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # 2. Set generation parameters if not already defined
         if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
@@ -524,7 +517,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config.max_length = generation_config.max_new_tokens + input_ids_length
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
         stop_words_logits_processor = StopWordsLogitsProcessor(
             stop_words_ids=stop_words_ids,
             eos_token_id=generation_config.eos_token_id,
@@ -541,7 +533,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             negative_prompt_attention_mask=None,
         )
         # 12. expand input_ids with `num_return_sequences` additional sequences per batch
         input_ids, model_kwargs = self._expand_inputs_for_generation(
             input_ids=input_ids,
@@ -552,8 +543,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # 13. run sample
-        pad_token_id=generation_config.pad_token_id
-        eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
         # init values
         stopping_criteria = self._get_stopping_criteria(
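
The sampling path unpacks pad_token_id and puts the eos id into a tensor on the same device as input_ids so finished sequences can be tracked with tensor ops. A minimal sketch of how such an eos tensor is typically used in a hand-rolled sample loop, assuming a single eos id (names are illustrative, not taken from this file):

import torch

def mask_finished(next_tokens, unfinished, pad_token_id, eos_token_id_tensor):
    # Finished sequences keep emitting the pad token.
    next_tokens = next_tokens * unfinished + pad_token_id * (1 - unfinished)
    # A sequence is done once it emits the eos id.
    unfinished = unfinished * next_tokens.ne(eos_token_id_tensor).long()
    return next_tokens, unfinished

tokens = torch.tensor([5, 2])                 # suppose 2 is the eos id
unfinished = torch.ones(2, dtype=torch.long)  # both sequences still running
eos = torch.tensor([2])
tokens, unfinished = mask_finished(tokens, unfinished, 0, eos)
print(tokens, unfinished)  # tensor([5, 2]) tensor([1, 0])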