Refine.
This commit is contained in:
parent
c462129ba6
commit
bfc3fb6706
|
@ -752,7 +752,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
)
|
)
|
||||||
next_token_logits = logits[:, 0, :]
|
next_token_logits = logits[:, -1, :]
|
||||||
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue