Refine model of qwen.
parent 12dcbec718
commit 4d493014ba
@@ -52,7 +52,6 @@ print(model)
 
 
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-# model = model.from_pretrained(model_dir, config=config, device_map="cuda:1", trust_remote_code=True)
 model = model.from_pretrained(model_dir).cuda()
 
 # model = model.eval()
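For context, the loading path this hunk leaves in place can be exercised with the stock transformers API as below. This is a minimal sketch, not the script's verbatim code: model_dir is a placeholder for the checkpoint path, and AutoModelForCausalLM stands in for the script's own model class built from the trimmed modeling file.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_dir = "Qwen/Qwen-7B"  # placeholder; the script's actual model_dir is defined elsewhere

# trust_remote_code is needed because Qwen checkpoints ship their own tokenizer/modeling code
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True).cuda()
model.eval()  # the script leaves this commented out; shown here for completeness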
@@ -1,8 +1,3 @@
-# Copyright (c) Alibaba Cloud.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 import copy
 import math
 import inspect
@@ -443,30 +438,14 @@ class QWenLMHeadModel(nn.Module):
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         generation_config = self.generation_config
 
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
         generation_config.validate()
-
-        # 2. Set generation parameters if not already defined
-
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            generation_config.pad_token_id = eos_token_id
-
-        # 6. Prepare `max_length` depending on other stopping criteria.
-        input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-
         pad_token_id = generation_config.pad_token_id
         eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
-
-        # init attention / hidden states / scores tuples
         scores = None
 
         # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
 
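The block deleted here is the stock transformers bookkeeping that backfills pad_token_id from eos_token_id and derives max_length from max_new_tokens plus the prompt length, so the trimmed generate() now assumes the loaded generation_config already carries usable pad_token_id, eos_token_id, and length settings. The surviving eos_token_id_tensor / unfinished_sequences pair feeds the usual HF-style end-of-sequence tracking later in the loop. A self-contained sketch of that pattern with hypothetical toy values (not the file's verbatim code):

import torch

pad_token_id = 0                          # hypothetical ids for illustration
eos_token_id_tensor = torch.tensor([2])
unfinished_sequences = torch.ones(3, dtype=torch.long)
next_tokens = torch.tensor([5, 2, 7])     # row 1 just sampled the EOS id

# finished rows keep emitting pad_token_id instead of fresh tokens
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# a row is marked finished once it emits any of the EOS ids
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
    .ne(eos_token_id_tensor.unsqueeze(1))
    .prod(dim=0)
)
print(unfinished_sequences)  # tensor([1, 0, 1]); generation stops when this hits all zeros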
@@ -477,15 +456,16 @@ class QWenLMHeadModel(nn.Module):
-
             # forward pass to get next token
             outputs = self(**model_inputs)
 
             next_token_scores = outputs.logits[:, -1, :]
 
+            # repetition_penalty
             penalty = self.config.repetition_penalty
             score = torch.gather(next_token_scores, 1, input_ids)
             # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
             score = torch.where(score < 0, score * penalty, score / penalty)
             next_token_scores = next_token_scores.scatter_(1, input_ids, score)
 
+            # top_p
             top_p = self.config.top_p
             filter_value = -float("Inf")
             min_tokens_to_keep = 1
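The gather/scatter_ pair above applies the CTRL-style repetition penalty in place: each previously generated id has its logit divided by penalty when positive and multiplied when negative, shrinking its probability either way. The hunk then ends just as the nucleus-sampling variables are set up, so the actual filtering falls outside this diff. For reference, a self-contained sketch of how top_p, filter_value, and min_tokens_to_keep are conventionally used (this mirrors transformers' TopPLogitsWarper; it is an illustration, not the file's verbatim continuation):

import torch

def top_p_filter(logits, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # sort ascending so tokens outside the nucleus sit at the front
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # drop tokens whose cumulative mass falls outside top-p,
    # but always keep at least min_tokens_to_keep candidates
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False

    # map the mask back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, filter_value)

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
print(top_p_filter(logits, top_p=0.9))  # the lowest-probability token becomes -inf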