Remove return_dict_in_generate

Colin 2024-01-07 17:32:24 +08:00
parent a8f2fbbff5
commit 3f8ea9db07
1 changed file with 1 addition and 29 deletions


@@ -724,7 +724,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
-            return_dict_in_generate=False,
             generation_config=generation_config,
             **kwargs,
         )
@@ -810,7 +809,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         outputs = []
         for token in self.generate_stream(
             input_ids,
-            return_dict_in_generate=False,
             generation_config=stream_config,
             logits_processor=logits_processor,
             seed=-1,
@@ -1074,7 +1072,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             pad_token_id=generation_config.pad_token_id,
             eos_token_id=generation_config.eos_token_id,
             output_scores=generation_config.output_scores,
-            return_dict_in_generate=generation_config.return_dict_in_generate,
             synced_gpus=synced_gpus,
             streamer=streamer,
             **model_kwargs,
@@ -1092,7 +1089,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
@@ -1148,23 +1144,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if output_hidden_states is not None
             else self.generation_config.output_hidden_states
         )
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )

         # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        cross_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        decoder_hidden_states = (
-            () if (return_dict_in_generate and output_hidden_states) else None
-        )
+        scores = None

         # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(
@@ -1190,16 +1172,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             next_token_scores = logits_processor(input_ids, next_token_logits)
             next_token_scores = logits_warper(input_ids, next_token_scores)

-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores,)
-                if output_attentions:
-                    decoder_attentions += (outputs.attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (outputs.hidden_states,)
-
             # sample
             probs = nn.functional.softmax(next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
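
For reference, a minimal standalone sketch of the per-step sampling that the trimmed loop still performs, assuming the Hugging Face transformers logits-processor API used in the surrounding code; the input ids, logits tensor, and warper settings below are illustrative placeholders, not values from this commit.

# Minimal sketch (illustrative, not part of the commit): the sampling step that
# remains after the return_dict_in_generate bookkeeping was removed.
import torch
from torch import nn
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

logits_processor = LogitsProcessorList()  # left empty in this sketch
logits_warper = LogitsProcessorList(
    [TemperatureLogitsWarper(0.7), TopPLogitsWarper(0.9)]  # placeholder settings
)

input_ids = torch.tensor([[1, 2, 3]])       # hypothetical prompt ids, shape (batch, seq)
next_token_logits = torch.randn(1, 151936)  # stand-in for the model's last-step logits

# warp the logits, softmax them, then draw one token per sequence
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
print(next_tokens.shape)  # torch.Size([1])

With the return_dict_in_generate branch gone, scores is simply left as None and the loop keeps only the sampled token ids.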