Refine model of qwen.

This commit is contained in:
Colin 2024-01-20 20:20:18 +08:00
parent 12dcbec718
commit 4d493014ba
2 changed files with 2 additions and 23 deletions

View File

@ -52,7 +52,6 @@ print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# model = model.from_pretrained(model_dir, config=config, device_map="cuda:1", trust_remote_code=True)
model = model.from_pretrained(model_dir).cuda()
# model = model.eval()

View File

@ -1,8 +1,3 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import math
import inspect
@ -443,30 +438,14 @@ class QWenLMHeadModel(nn.Module):
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
# 2. Set generation parameters if not already defined
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
generation_config.pad_token_id = eos_token_id
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
pad_token_id = generation_config.pad_token_id
eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
# init attention / hidden states / scores tuples
scores = None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
@ -477,15 +456,16 @@ class QWenLMHeadModel(nn.Module):
# forward pass to get next token
outputs = self(**model_inputs)
next_token_scores = outputs.logits[:, -1, :]
# repetition_penalty
penalty = self.config.repetition_penalty
score = torch.gather(next_token_scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * penalty, score / penalty)
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
# top_p
top_p = self.config.top_p
filter_value = -float("Inf")
min_tokens_to_keep = 1