Refine model of qwen.
This commit is contained in:
parent
12dcbec718
commit
4d493014ba
|
@ -52,7 +52,6 @@ print(model)
|
|||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
# model = model.from_pretrained(model_dir, config=config, device_map="cuda:1", trust_remote_code=True)
|
||||
model = model.from_pretrained(model_dir).cuda()
|
||||
|
||||
# model = model.eval()
|
||||
|
|
|
@ -1,8 +1,3 @@
|
|||
# Copyright (c) Alibaba Cloud.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import math
|
||||
import inspect
|
||||
|
@ -443,30 +438,14 @@ class QWenLMHeadModel(nn.Module):
|
|||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
generation_config = self.generation_config
|
||||
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||
generation_config.validate()
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
|
||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||
|
||||
|
@ -477,15 +456,16 @@ class QWenLMHeadModel(nn.Module):
|
|||
|
||||
# forward pass to get next token
|
||||
outputs = self(**model_inputs)
|
||||
|
||||
next_token_scores = outputs.logits[:, -1, :]
|
||||
|
||||
# repetition_penalty
|
||||
penalty = self.config.repetition_penalty
|
||||
score = torch.gather(next_token_scores, 1, input_ids)
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
||||
score = torch.where(score < 0, score * penalty, score / penalty)
|
||||
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
|
||||
|
||||
# top_p
|
||||
top_p = self.config.top_p
|
||||
filter_value = -float("Inf")
|
||||
min_tokens_to_keep = 1
|
||||
|
|
Loading…
Reference in New Issue