Remove return_dict_in_generate
This commit is contained in:
parent a8f2fbbff5
commit 3f8ea9db07
@@ -724,7 +724,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
-            return_dict_in_generate=False,
             generation_config=generation_config,
             **kwargs,
         )
@@ -810,7 +809,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         outputs = []
         for token in self.generate_stream(
             input_ids,
-            return_dict_in_generate=False,
             generation_config=stream_config,
             logits_processor=logits_processor,
             seed=-1,
@@ -1074,7 +1072,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             pad_token_id=generation_config.pad_token_id,
             eos_token_id=generation_config.eos_token_id,
             output_scores=generation_config.output_scores,
-            return_dict_in_generate=generation_config.return_dict_in_generate,
             synced_gpus=synced_gpus,
             streamer=streamer,
             **model_kwargs,
@@ -1092,7 +1089,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
@@ -1148,23 +1144,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if output_hidden_states is not None
             else self.generation_config.output_hidden_states
         )
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )

         # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        cross_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        decoder_hidden_states = (
-            () if (return_dict_in_generate and output_hidden_states) else None
-        )
+        scores = None

         # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(
|
@ -1190,16 +1172,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||||
|
|
||||||
# Store scores, attentions and hidden_states when required
|
|
||||||
if return_dict_in_generate:
|
|
||||||
if output_scores:
|
|
||||||
scores += (next_token_scores,)
|
|
||||||
if output_attentions:
|
|
||||||
decoder_attentions += (outputs.attentions,)
|
|
||||||
|
|
||||||
if output_hidden_states:
|
|
||||||
decoder_hidden_states += (outputs.hidden_states,)
|
|
||||||
|
|
||||||
# sample
|
# sample
|
||||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
|
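For context (a sketch, not part of this commit): with return_dict_in_generate dropped from this custom sampling path, the path keeps returning plain token-id tensors, and a caller that still wants per-step scores or attentions would go through the stock transformers generate API instead. The sketch below assumes a standard AutoModelForCausalLM / AutoTokenizer setup; the checkpoint name is a placeholder.

# Sketch only, assuming stock transformers behaviour; "some/causal-lm" is a placeholder checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some/causal-lm")
model = AutoModelForCausalLM.from_pretrained("some/causal-lm")
input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# Plain-tensor result: a LongTensor of shape (batch, prompt_len + new_tokens),
# the same style of return the simplified sample loop above keeps.
plain = model.generate(input_ids, max_new_tokens=8, do_sample=True)

# Opting back in via the standard generate path: a ModelOutput carrying
# .sequences plus a per-step tuple of .scores.
rich = model.generate(
    input_ids,
    max_new_tokens=8,
    do_sample=True,
    return_dict_in_generate=True,
    output_scores=True,
)
print(plain.shape, rich.sequences.shape, len(rich.scores))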