From d13f7e6c5753147a9c939a00a0c1026a30a4dd08 Mon Sep 17 00:00:00 2001
From: Colin
Date: Sat, 13 Jan 2024 17:16:43 +0800
Subject: [PATCH] Format model of qwen.

---
 qwen/modeling_qwen.py | 21 ++++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py
index 856fbff..e79996d 100644
--- a/qwen/modeling_qwen.py
+++ b/qwen/modeling_qwen.py
@@ -38,6 +38,7 @@ from qwen_generation_utils import (
 )
 
 import sys
+
 sys.path.append("..")
 from tools import show
 
@@ -144,7 +145,6 @@ class QWenAttention(nn.Module):
         else:
             attention_mask = causal_mask
 
-
         # qk = query @ key.transpose(-2, -1)
         # qk = qk[0]
         # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
@@ -441,12 +441,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
-            tokenizer,
-            query,
-            query_assistant,
-            history=history,
-            system=system,
-            max_window_size=max_window_size
+            tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
         )
 
         stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
@@ -469,13 +464,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
-        stop_words_ids = [],
+        stop_words_ids=[],
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         generation_config = self.generation_config
 
-
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
 
@@ -486,7 +480,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         # 2. Set generation parameters if not already defined
 
-
         if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
@@ -524,7 +517,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config.max_length = generation_config.max_new_tokens + input_ids_length
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
-
         stop_words_logits_processor = StopWordsLogitsProcessor(
             stop_words_ids=stop_words_ids,
             eos_token_id=generation_config.eos_token_id,
         )
@@ -541,7 +533,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             negative_prompt_attention_mask=None,
         )
 
-
         # 12. expand input_ids with `num_return_sequences` additional sequences per batch
         input_ids, model_kwargs = self._expand_inputs_for_generation(
             input_ids=input_ids,
@@ -548,12 +539,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             expand_size=generation_config.num_return_sequences,
             is_encoder_decoder=self.config.is_encoder_decoder,
             **model_kwargs,
         )
 
         # 13. run sample
-        pad_token_id=generation_config.pad_token_id
-        eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
-
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
+
         # init values
         stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=StoppingCriteriaList()