Remote return_dict_in_generate
This commit is contained in:
parent
a8f2fbbff5
commit
3f8ea9db07
|
@ -724,7 +724,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
outputs = self.generate(
|
||||
input_ids,
|
||||
stop_words_ids=stop_words_ids,
|
||||
return_dict_in_generate=False,
|
||||
generation_config=generation_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -810,7 +809,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
outputs = []
|
||||
for token in self.generate_stream(
|
||||
input_ids,
|
||||
return_dict_in_generate=False,
|
||||
generation_config=stream_config,
|
||||
logits_processor=logits_processor,
|
||||
seed=-1,
|
||||
|
@ -1074,7 +1072,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_scores=generation_config.output_scores,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
|
@ -1092,7 +1089,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: bool = False,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
**model_kwargs,
|
||||
|
@ -1148,23 +1144,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
if output_hidden_states is not None
|
||||
else self.generation_config.output_hidden_states
|
||||
)
|
||||
return_dict_in_generate = (
|
||||
return_dict_in_generate
|
||||
if return_dict_in_generate is not None
|
||||
else self.generation_config.return_dict_in_generate
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
scores = None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(
|
||||
|
@ -1190,16 +1172,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (outputs.attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (outputs.hidden_states,)
|
||||
|
||||
# sample
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
|
Loading…
Reference in New Issue