diff --git a/qwen/demo.py b/qwen/demo.py
index 382c419..6b7af99 100644
--- a/qwen/demo.py
+++ b/qwen/demo.py
@@ -52,7 +52,6 @@ print(model)
 
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 
-# model = model.from_pretrained(model_dir, config=config, device_map="cuda:1", trust_remote_code=True)
 model = model.from_pretrained(model_dir).cuda()
 # model = model.eval()
 
diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py
index 6d87982..14dbf0c 100644
--- a/qwen/modeling_qwen.py
+++ b/qwen/modeling_qwen.py
@@ -1,8 +1,3 @@
-# Copyright (c) Alibaba Cloud.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 import copy
 import math
 import inspect
@@ -443,30 +438,14 @@ class QWenLMHeadModel(nn.Module):
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         generation_config = self.generation_config
-        generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         generation_config.validate()
 
-        # 2. Set generation parameters if not already defined
-
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            generation_config.pad_token_id = eos_token_id
-
-        # 6. Prepare `max_length` depending on other stopping criteria.
-        input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-
         pad_token_id = generation_config.pad_token_id
         eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
 
-        # init attention / hidden states / scores tuples
        scores = None
 
-        # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
 
@@ -477,15 +456,16 @@ class QWenLMHeadModel(nn.Module):
 
             # forward pass to get next token
             outputs = self(**model_inputs)
-
             next_token_scores = outputs.logits[:, -1, :]
 
+            # repetition_penalty
             penalty = self.config.repetition_penalty
             score = torch.gather(next_token_scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
             score = torch.where(score < 0, score * penalty, score / penalty)
             next_token_scores = next_token_scores.scatter_(1, input_ids, score)
 
+            # top_p
             top_p = self.config.top_p
             filter_value = -float("Inf")
             min_tokens_to_keep = 1
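
For reference, the modeling_qwen.py hunks inline a stripped-down version of the transformers GenerationMixin.generate() path: the config deepcopy, the pad-token defaulting, and the max_length bookkeeping are dropped, and the per-step logit processing (repetition penalty, then top-p) is written out by hand. Since the last hunk is cut off right after min_tokens_to_keep = 1, below is a minimal standalone sketch of both steps. The helper names are illustrative, not from the repo, and the top-p body follows the standard TopPLogitsWarper recipe from transformers rather than anything visible in the diff:

import torch

def apply_repetition_penalty(next_token_scores, input_ids, penalty):
    # Gather the logits of every token that already appears in the sequence.
    score = torch.gather(next_token_scores, 1, input_ids)
    # Negative logits are multiplied by the penalty (pushed further down),
    # positive ones are divided, so repeated tokens always become less likely.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return next_token_scores.scatter_(1, input_ids, score)

def top_p_filter(next_token_scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # Sort ascending so the lowest-probability tokens come first, then
    # accumulate their probability mass.
    sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Mask everything whose cumulative mass falls outside the top-p nucleus,
    # but always keep at least min_tokens_to_keep candidates.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return next_token_scores.masked_fill(indices_to_remove, filter_value)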
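
The second hunk also sets up pad_token_id, eos_token_id_tensor, and unfinished_sequences, but their use falls outside the diff. Assuming the loop tail follows the stock transformers sampling loop, one step would look roughly like this (sample_step is a hypothetical helper, not code from the repo):

import torch

def sample_step(next_token_scores, input_ids, unfinished_sequences,
                pad_token_id, eos_token_id_tensor):
    # Sample the next token from the filtered distribution.
    probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    # Sequences that already produced EOS keep emitting pad_token_id.
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    # A row stays unfinished only while none of its EOS ids has appeared.
    unfinished_sequences = unfinished_sequences.mul(
        next_tokens.tile(eos_token_id_tensor.shape[0], 1)
        .ne(eos_token_id_tensor.unsqueeze(1))
        .prod(dim=0)
    )
    done = unfinished_sequences.max() == 0
    return input_ids, unfinished_sequences, done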