Compare commits

...

3 Commits

Author SHA1 Message Date
Colin 4d493014ba Refine model of qwen. 2024-01-20 20:20:18 +08:00
Colin 12dcbec718 PreTrainedModel to nn.Module 2024-01-20 20:06:59 +08:00
Colin 0458e7303c Remove attention_mask 2024-01-20 18:08:20 +08:00
4 changed files with 127 additions and 184 deletions

View File

@@ -33,5 +33,14 @@
   "use_dynamic_ntk": true,
   "use_flash_attn": "auto",
   "use_logn_attn": true,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "chat_format": "chatml",
+  "eos_token_id": 151643,
+  "pad_token_id": 151643,
+  "max_window_size": 6144,
+  "max_new_tokens": 512,
+  "do_sample": true,
+  "top_k": 0,
+  "top_p": 0.8,
+  "repetition_penalty": 1.1
 }
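Note: the generation parameters that used to live in generation_config.json (removed later in this compare) are folded into config.json here. A minimal sketch, assuming these keys map onto standard transformers GenerationConfig fields and that `model_dir` is a placeholder path, of how a GenerationConfig can be rebuilt from the merged file:

import json
import os

from transformers import GenerationConfig

model_dir = "path/to/qwen"  # placeholder, not a path from this repo
with open(os.path.join(model_dir, "config.json")) as f:
    cfg = json.load(f)

# Pick out only the sampling-related keys added above.
gen_keys = ("do_sample", "top_k", "top_p", "repetition_penalty",
            "eos_token_id", "pad_token_id", "max_new_tokens")
generation_config = GenerationConfig(**{k: cfg[k] for k in gen_keys if k in cfg})
print(generation_config)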

View File

@@ -52,10 +52,10 @@ print(model)
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-model = model.from_pretrained(model_dir, config=config, device_map="auto", trust_remote_code=True)
+model = model.from_pretrained(model_dir).cuda()
 # model = model.eval()
 model = model.train()  # control by @torch.no_grad()
 # Different generation lengths, top_p, and other related hyperparameters can be specified here
 # model.generation_config = GenerationConfig.from_pretrained(
@@ -80,7 +80,7 @@ print(decode_tokens)
 # Second-round dialogue
 response, history, decode_tokens = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", "", history=None)
-print(response)
+print(decode_tokens)
 # <|im_start|>system
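The demo now constructs the model itself, loads weights through the custom from_pretrained shown further down, and moves it to the GPU. A self-contained sketch of the updated usage; the import module names, the QWenConfig loading step, and `model_dir` are assumptions, while the call pattern follows the diff above:

from transformers import AutoTokenizer

from configuration_qwen import QWenConfig   # assumed module name
from modeling_qwen import QWenLMHeadModel   # assumed module name

model_dir = "path/to/qwen"  # placeholder
config = QWenConfig.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

model = QWenLMHeadModel(config)
model = model.from_pretrained(model_dir).cuda()
model = model.train()  # gradients stay disabled via @torch.no_grad() on chat()

response, history, decode_tokens = model.chat(tokenizer, "Hello", "", history=None)
print(decode_tokens)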

View File

@@ -1,12 +0,0 @@
-{
-  "chat_format": "chatml",
-  "eos_token_id": 151643,
-  "pad_token_id": 151643,
-  "max_window_size": 6144,
-  "max_new_tokens": 512,
-  "do_sample": true,
-  "top_k": 0,
-  "top_p": 0.8,
-  "repetition_penalty": 1.1,
-  "transformers_version": "4.31.0"
-}

View File

@@ -1,11 +1,10 @@
-# Copyright (c) Alibaba Cloud.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
 import copy
 import math
 import inspect
+import os
+import gc
+from tqdm import auto as tqdm_lib
+import json
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 import torch
@@ -37,6 +36,11 @@ from qwen_generation_utils import (
     StopWordsLogitsProcessor,
 )
+from safetensors import safe_open
+from safetensors.torch import load_file as safe_load_file
+from safetensors.torch import save_file as safe_save_file
 import sys
 sys.path.append("..")
@@ -96,7 +100,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
     ):
         mixed_x_layer = self.c_attn(hidden_states)
         query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@@ -120,32 +123,21 @@ class QWenAttention(nn.Module):
             query = query * logn_tensor.expand_as(query)
         key_size = key.size(1)
-        if query.size(1) == key_size:
-            causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
-                1, 1, key_size, key_size
-            )
-        else:
-            causal_mask = None
+        causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
+            1, 1, key_size, key_size
+        )
         query = query.permute(0, 2, 1, 3)
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)
-        if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
-            if causal_mask is not None:
-                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
-        else:
-            attention_mask = causal_mask
         # qk = query @ key.transpose(-2, -1)
         # qk = qk[0]
-        # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
-        # print("layer:" + str(self.index) + " query.shape:"+ str(query.shape))
-        # print("layer:" + str(self.index) + " key.shape:"+ str(key.shape))
-        # print("layer:" + str(self.index) + " value.shape:"+ str(value.shape))
-        # print("\n")
+        # prePath = "../generated/query_matmul_key/img/"
+        # show.DumpTensorToImage(
+        #     qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
+        # )
-        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
+        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(context_layer)
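With the padding attention_mask gone, the attention call above reduces to plain causal scaled-dot-product attention. A standalone sketch (PyTorch 2.x; shapes are illustrative assumptions) checking that the boolean lower-triangular mask built above matches SDPA's built-in is_causal path:

import torch
import torch.nn.functional as F

batch, heads, seq, dim = 1, 2, 5, 8
query = torch.randn(batch, heads, seq, dim)
key = torch.randn(batch, heads, seq, dim)
value = torch.randn(batch, heads, seq, dim)

# Same mask construction as in QWenAttention.forward above.
causal_mask = torch.tril(torch.ones((seq, seq), dtype=torch.bool)).view(1, 1, seq, seq)

out_mask = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
assert torch.allclose(out_mask, out_causal, atol=1e-6)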
@@ -189,15 +181,10 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
     ):
         layernorm_output = self.ln_1(hidden_states)
-        attn_outputs = self.attn(
-            layernorm_output,
-            rotary_pos_emb_list,
-            attention_mask=attention_mask,
-        )
+        attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list)
         attn_output = attn_outputs[0]
         residual = hidden_states
         layernorm_input = attn_output + residual
@@ -209,7 +196,7 @@ class QWenBlock(nn.Module):
         return hidden_states
-class QWenPreTrainedModel(PreTrainedModel):
+class QWenPreTrainedModel(nn.Module):
     config_class = QWenConfig
     base_model_prefix = "transformer"
     is_parallelizable = False
@@ -217,7 +204,7 @@ class QWenPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["QWenBlock"]
     def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
+        super().__init__()
 class QWenModel(QWenPreTrainedModel):
@@ -248,8 +235,6 @@ class QWenModel(QWenPreTrainedModel):
             eps=config.layer_norm_epsilon,
         )
-        self.post_init()
     def get_ntk_alpha(self, true_seq_len):
         context_value = math.log(true_seq_len / self.seq_length, 2) + 1
         ntk_alpha = 2 ** math.ceil(context_value) - 1
@@ -259,8 +244,6 @@ class QWenModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
         if input_ids is not None and inputs_embeds is not None:
@@ -275,14 +258,6 @@ class QWenModel(QWenPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
-        if attention_mask is not None:
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
         hidden_states = inputs_embeds
@@ -295,15 +270,8 @@ class QWenModel(QWenPreTrainedModel):
             ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
         else:
             ntk_alpha_list = []
-            if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
-                for i in range(hidden_states.size()[0]):
-                    true_seq_len = true_seq_lens[i].item()
-                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
-                    ntk_alpha_list.append(ntk_alpha)
-            else:
-                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
-                ntk_alpha_list.append(ntk_alpha)
+            ntk_alpha = self.get_ntk_alpha(kv_seq_len)
+            ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
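Dropping the per-sample true_seq_len path means ntk_alpha is now always derived from the full kv_seq_len. A small worked example of get_ntk_alpha as defined above; the seq_length of 2048 and the final clamp to 1 are assumptions based on the upstream Qwen implementation:

import math

def get_ntk_alpha(true_seq_len, seq_length=2048):
    context_value = math.log(true_seq_len / seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1.0)  # clamp follows the upstream implementation

print(get_ntk_alpha(1024))  # 1.0 -> within the training window, no NTK scaling
print(get_ntk_alpha(4096))  # 3   -> 2 ** ceil(log2(2) + 1) - 1
print(get_ntk_alpha(6144))  # 7   -> 2 ** ceil(log2(3) + 1) - 1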
@@ -312,52 +280,34 @@
         all_hidden_states = None
         for block in self.h:
-            hidden_states = block(
-                hidden_states,
-                rotary_pos_emb_list=rotary_pos_emb_list,
-                attention_mask=attention_mask,
-            )
+            hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
-class QWenLMHeadModel(QWenPreTrainedModel):
+class QWenLMHeadModel(nn.Module):
     def __init__(self, config):
-        super().__init__(config)
+        super().__init__()
+        self.config = config
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.post_init()
+        self.generation_config = GenerationConfig.from_model_config(config)
     def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
-        if input_ids.size(0) == 1:
-            attention_mask = None
-        else:
-            attention_mask = kwargs.get("attention_mask", None)
         model_inputs = {"input_ids": input_ids}
-        model_inputs.update(
-            {
-                "attention_mask": attention_mask,
-            }
-        )
         return model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         transformer_outputs = self.transformer(
             input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
         )
         hidden_states = transformer_outputs[0]
@@ -387,6 +337,59 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
+        load_in_8bit = False
+        load_in_4bit = False
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
+        print(f"loading weights file {resolved_archive_file}")
+        with open(resolved_archive_file, "r") as f:
+            index = json.loads(f.read())
+        shard_filenames = sorted(set(index["weight_map"].values()))
+        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
+        model = cls._load_pretrained_model(resolved_archive_file)
+        model.is_loaded_in_4bit = load_in_4bit
+        model.is_loaded_in_8bit = load_in_8bit
+        return model
+    def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
+        metadata = getattr(state_dict, "_metadata", None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+        error_msgs = []
+        def load(module: nn.Module, state_dict, prefix=""):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+                module._load_from_state_dict(*args)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")
+        load(model_to_load, state_dict, prefix=start_prefix)
+        del state_dict
+        return error_msgs
+    def _load_pretrained_model(cls, resolved_archive_file):
+        start_prefix = ""
+        model_to_load = cls
+        error_msgs = []
+        if len(resolved_archive_file) > 1:
+            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+        for shard_file in resolved_archive_file:
+            state_dict = safe_load_file(shard_file)
+            error_msgs += cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+            del state_dict  # force memory release
+            gc.collect()
+        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
+        return cls
     @torch.no_grad()
     def chat(
         self,
@@ -406,18 +409,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         stop_words_ids = []
-        max_window_size = kwargs.get("max_window_size", None)
-        if max_window_size is None:
-            max_window_size = generation_config.max_window_size
-        raw_text, context_tokens = make_context(
-            tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
-        )
+        raw_text, context_tokens = make_context(tokenizer, query, query_assistant, history=history, system=system)
         stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
-        input_ids = torch.tensor([context_tokens]).to(self.device)
+        input_ids = torch.tensor([context_tokens]).to(next(self.parameters()).device)
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
-            tokenizer=tokenizer,
             **kwargs,
         )
         decoded, response, end_reason = decode_tokens(
@@ -432,110 +431,53 @@
     def generate(
         self,
-        inputs: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
         stop_words_ids=[],
-        tokenizer=None,
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         generation_config = self.generation_config
-        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        self._validate_model_class()
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         generation_config.validate()
-        self._validate_model_kwargs(model_kwargs.copy())
-        # 2. Set generation parameters if not already defined
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            eos_token_id = generation_config.eos_token_id
-            if isinstance(eos_token_id, list):
-                eos_token_id = eos_token_id[0]
-            generation_config.pad_token_id = eos_token_id
-        # 3. Define model inputs
-        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
-            inputs, generation_config.bos_token_id, model_kwargs
-        )
-        # 4. Define other model kwargs
-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
-        requires_attention_mask = "encoder_outputs" not in model_kwargs
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
-            )
-        # 5. Prepare `input_ids` which will be used for auto-regressive generation
-        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
-        # 6. Prepare `max_length` depending on other stopping criteria.
-        input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
-        stop_words_logits_processor = StopWordsLogitsProcessor(
-            stop_words_ids=stop_words_ids,
-            eos_token_id=generation_config.eos_token_id,
-        )
-        logits_processor = LogitsProcessorList([stop_words_logits_processor])
-        logits_processor = self._get_logits_processor(
-            generation_config=generation_config,
-            input_ids_seq_length=input_ids_length,
-            encoder_input_ids=inputs_tensor,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            logits_processor=logits_processor,
-            model_kwargs=model_kwargs,
-            negative_prompt_ids=None,
-            negative_prompt_attention_mask=None,
-        )
-        # 12. expand input_ids with `num_return_sequences` additional sequences per batch
-        input_ids, model_kwargs = self._expand_inputs_for_generation(
-            input_ids=input_ids,
-            expand_size=generation_config.num_return_sequences,
-            is_encoder_decoder=False,
-            **model_kwargs,
-        )
-        # 13. run sample
         pad_token_id = generation_config.pad_token_id
         eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
-        # init values
-        stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
-        )
-        logits_warper = self._get_logits_warper(generation_config)
-        # init attention / hidden states / scores tuples
         scores = None
         # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         this_peer_finished = False
         # auto-regressive generation
         while True:
-            # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             # forward pass to get next token
             outputs = self(**model_inputs)
             next_token_scores = outputs.logits[:, -1, :]
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_scores)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            # repetition_penalty
+            penalty = self.config.repetition_penalty
+            score = torch.gather(next_token_scores, 1, input_ids)
+            # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
+            score = torch.where(score < 0, score * penalty, score / penalty)
+            next_token_scores = next_token_scores.scatter_(1, input_ids, score)
+            # top_p
+            top_p = self.config.top_p
+            filter_value = -float("Inf")
+            min_tokens_to_keep = 1
+            sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
+            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+            # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+            sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+            # scatter sorted tensors to original indexing
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
             # sample
             probs = nn.functional.softmax(next_token_scores, dim=-1)
@@ -545,20 +487,24 @@
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
             unfinished_sequences = unfinished_sequences.mul(
                 next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
             )
+            # decoded, response, end_reason = decode_tokens(
+            #     next_tokens,
+            #     tokenizer,
+            #     raw_text_len=0,
+            #     context_length=0,
+            #     errors="replace",
+            # )
+            # print(decoded)
             # stop when each sentence is finished
             if unfinished_sequences.max() == 0:
                 this_peer_finished = True
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
             if this_peer_finished:
                 break
         return input_ids
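The rewritten generate() inlines what the transformers logits-processor pipeline used to provide: a repetition penalty followed by nucleus (top-p) filtering and multinomial sampling. A standalone sketch of that single sampling step; the penalty and top_p defaults mirror the values now stored in config.json, and the vocab size is taken from it as well:

import torch
import torch.nn.functional as F

def sample_step(next_token_scores, input_ids, penalty=1.1, top_p=0.8):
    # Repetition penalty: push already-generated tokens away from being re-sampled.
    score = torch.gather(next_token_scores, 1, input_ids)
    score = torch.where(score < 0, score * penalty, score / penalty)
    next_token_scores = next_token_scores.scatter(1, input_ids, score)

    # Nucleus (top-p) filtering: drop the low-probability tail, keep at least one token.
    sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[..., -1:] = False
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    next_token_scores = next_token_scores.masked_fill(indices_to_remove, -float("Inf"))

    # Sample one next token per sequence.
    probs = F.softmax(next_token_scores, dim=-1)
    return torch.multinomial(probs, num_samples=1)

vocab_size = 151936  # from config.json above
logits = torch.randn(1, vocab_size)
prev_ids = torch.randint(0, vocab_size, (1, 8))
print(sample_step(logits, prev_ids))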