Reformat qwen.
parent aa2d3b96c4
commit 4c0991a409
@@ -50,6 +50,7 @@ logger = logging.get_logger(__name__)

_SENTINEL = object()


class QWenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
@@ -67,33 +68,22 @@ class QWenAttention(nn.Module):
        self.projection_size = config.kv_channels * config.num_attention_heads

        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads

        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)

        self.c_proj = nn.Linear(
            config.hidden_size, self.projection_size, bias=not config.no_bias
        )
        self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)

        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.use_logn_attn = config.use_logn_attn

        logn_list = [
            math.log(i, self.seq_length) if i > self.seq_length else 1
            for i in range(1, 32768)
        ]
        logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
        logn_tensor = torch.tensor(logn_list)[None, :, None, None]
        self.register_buffer("logn_tensor", logn_tensor, persistent=False)

        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
        self.softmax_in_fp32 = (
            config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
        )
        self.use_cache_kernel = (
            config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
        )
        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
        self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
        cache_dtype = torch.float
        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
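Aside, not part of the diff: the log-n attention scaling prepared in __init__ above can be reproduced in isolation. A minimal sketch, assuming a placeholder seq_length of 2048 in place of the real config value:

    import math

    import torch

    seq_length = 2048  # assumed training context length, stand-in for config.seq_length
    logn_list = [math.log(i, seq_length) if i > seq_length else 1 for i in range(1, 32768)]
    logn_tensor = torch.tensor(logn_list)[None, :, None, None]  # broadcastable over (batch, seq, heads, head_dim)

    # Positions up to seq_length keep a scaling factor of 1.0; beyond that, queries
    # are scaled by log_{seq_length}(position), which grows slowly with context length.
    print(logn_tensor[0, 2047, 0, 0])  # position 2048 -> 1.0
    print(logn_tensor[0, 4095, 0, 0])  # position 4096 -> log_2048(4096) ≈ 1.09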
@@ -114,10 +104,6 @@ class QWenAttention(nn.Module):
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ):
        mixed_x_layer = self.c_attn(hidden_states)
@@ -146,9 +132,7 @@ class QWenAttention(nn.Module):
                rotary_pos_emb = (rotary_pos_emb,) * 2
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # Slice the pos emb for current inference
                query_list += [
                    apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
                ]
                query_list += [apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
                key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
            query = torch.cat(query_list, dim=0)
            key = torch.cat(key_list, dim=0)
@@ -173,9 +157,9 @@ class QWenAttention(nn.Module):

        key_size = key.size(1)
        if query.size(1) == key_size:
            causal_mask = torch.tril(
                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
            ).view(1, 1, key_size, key_size)
            causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
                1, 1, key_size, key_size
            )
        else:
            causal_mask = None
        query = query.permute(0, 2, 1, 3)
@@ -185,24 +169,16 @@ class QWenAttention(nn.Module):
        if attention_mask is not None:
            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
            if causal_mask is not None:
                attention_mask = attention_mask.masked_fill(
                    ~causal_mask, torch.finfo(query.dtype).min
                )
                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
        else:
            attention_mask = causal_mask
        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        ).transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)

        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        attn_output = self.c_proj(context_layer)

        outputs = (attn_output, present)
        if output_attentions:
            raise ValueError(
                "Cannot output attentions while using scaled_dot_product_attention"
            )

        return outputs
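Aside, not part of the diff: the masking path above can be exercised with a toy, fully made-up configuration (batch=1, heads=2, seq=4, head_dim=8); a boolean attn_mask passed to F.scaled_dot_product_attention means "attend where True":

    import torch
    import torch.nn.functional as F

    batch, heads, seq, head_dim = 1, 2, 4, 8  # assumed toy sizes, not the model's
    query = torch.randn(batch, heads, seq, head_dim)
    key = torch.randn(batch, heads, seq, head_dim)
    value = torch.randn(batch, heads, seq, head_dim)

    # Lower-triangular boolean mask, broadcast over batch and heads.
    key_size = key.size(2)
    causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool)).view(1, 1, key_size, key_size)

    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
    print(attn_output.shape)  # torch.Size([1, 4, 2, 8]): back to (batch, seq, heads, head_dim)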
@@ -210,12 +186,8 @@ class QWenAttention(nn.Module):
class QWenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(
            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
        )
        self.w2 = nn.Linear(
            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
        )
        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
        self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
        ff_dim_in = config.intermediate_size // 2
        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
@@ -263,10 +235,9 @@ class QWenBlock(nn.Module):
            rotary_pos_emb_list,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attn_output = attn_outputs[0]

        outputs = attn_outputs[1:]
@@ -299,8 +270,8 @@ class QWenPreTrainedModel(PreTrainedModel):
    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

class QWenModel(QWenPreTrainedModel):

class QWenModel(QWenPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
@@ -322,9 +293,7 @@ class QWenModel(QWenPreTrainedModel):
        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

        self.h = nn.ModuleList(
            [QWenBlock(config) for i in range(config.num_hidden_layers)]
        )
        self.h = nn.ModuleList([QWenBlock(config) for i in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
@@ -348,19 +317,13 @@ class QWenModel(QWenPreTrainedModel):
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None
        output_attentions: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
@@ -399,12 +362,7 @@ class QWenModel(QWenPreTrainedModel):
        else:
            ntk_alpha_list = []
            if attention_mask is not None and kv_seq_len > self.seq_length:
                true_seq_lens = (
                    attention_mask.squeeze(1)
                    .squeeze(1)
                    .eq(0)
                    .sum(dim=-1, dtype=torch.int32)
                )
                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
                for i in range(hidden_states.size()[0]):
                    true_seq_len = true_seq_lens[i].item()
                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -413,10 +371,7 @@ class QWenModel(QWenPreTrainedModel):
                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
                ntk_alpha_list.append(ntk_alpha)
        self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
        rotary_pos_emb_list = [
            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
            for ntk_alpha in ntk_alpha_list
        ]
        rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
@@ -425,7 +380,6 @@ class QWenModel(QWenPreTrainedModel):
        all_self_attentions = () if output_attentions else None
        all_hidden_states = None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            outputs = block(
                hidden_states,
                layer_past=layer_past,
@@ -443,9 +397,7 @@ class QWenModel(QWenPreTrainedModel):
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
@@ -458,7 +410,6 @@ class QWenModel(QWenPreTrainedModel):


class QWenLMHeadModel(QWenPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
@@ -466,9 +417,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
@@ -502,10 +451,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:


        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
@@ -515,7 +462,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions
            output_attentions=output_attentions,
        )
        hidden_states = transformer_outputs[0]
@@ -527,9 +474,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
            # shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -558,11 +503,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Tuple[str, HistoryType]:
        generation_config = (
            generation_config
            if generation_config is not None
            else self.generation_config
        )
        generation_config = generation_config if generation_config is not None else self.generation_config

        assert stream is _SENTINEL
        assert generation_config.chat_format == "chatml"
@@ -587,9 +528,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(
            get_stop_words_ids(generation_config.chat_format, tokenizer)
        )
        stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
@@ -622,19 +561,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        generation_config = (
            generation_config
            if generation_config is not None
            else self.generation_config
        )
        generation_config = generation_config if generation_config is not None else self.generation_config

        # Process stop_words_ids.
        stop_words_ids = kwargs.pop("stop_words_ids", None)
@@ -671,9 +604,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
@@ -685,26 +616,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        self._validate_model_class()

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(
            **kwargs
        )  # All unused kwargs must be model kwargs
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        if (
            generation_config.pad_token_id is None
            and generation_config.eos_token_id is not None
        ):
        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -713,9 +633,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(
                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
            )
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs
@@ -735,46 +653,27 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        else:
            model_kwargs["use_cache"] = generation_config.use_cache

        accepts_attention_mask = "attention_mask" in set(
            inspect.signature(self.forward).parameters.keys()
        )
        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if (
            model_kwargs.get("attention_mask", None) is None
            and requires_attention_mask
            and accepts_attention_mask
        ):
            model_kwargs[
                "attention_mask"
            ] = self._prepare_attention_mask_for_generation(
        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids = (
            inputs_tensor
            if model_input_name == "input_ids"
            else model_kwargs.pop("input_ids")
        )
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        generation_config.max_length = (
            generation_config.max_new_tokens + input_ids_length
        )
        self._validate_generated_length(
            generation_config, input_ids_length, has_default_max_length
        )
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
@@ -835,53 +734,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        **model_kwargs,
    ):
        # init values
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_warper = (
            logits_warper if logits_warper is not None else LogitsProcessorList()
        )
        pad_token_id = (
            pad_token_id
            if pad_token_id is not None
            else self.generation_config.pad_token_id
        )
        eos_token_id = (
            eos_token_id
            if eos_token_id is not None
            else self.generation_config.eos_token_id
        )
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = (
            torch.tensor(eos_token_id).to(input_ids.device)
            if eos_token_id is not None
            else None
        )
        output_scores = (
            output_scores
            if output_scores is not None
            else self.generation_config.output_scores
        )
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.generation_config.output_attentions
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )

        # init attention / hidden states / scores tuples
        scores = None

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
@@ -890,10 +761,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                output_attentions=output_attentions
            )
            outputs = self(**model_inputs, output_attentions=output_attentions)

            next_token_logits = outputs.logits[:, -1, :]
@@ -908,27 +776,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError(
                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
                    )
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=False
            )
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            # stop when each sentence is finished
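Aside, not part of the diff: the unfinished_sequences bookkeeping above is easy to sanity-check with toy tensors (all values below are made up):

    import torch

    next_tokens = torch.tensor([5, 7, 9])          # one freshly sampled token per sequence
    eos_token_id_tensor = torch.tensor([7, 9])     # two assumed EOS ids
    unfinished_sequences = torch.ones(3, dtype=torch.long)

    # tile -> (num_eos, batch); ne -> True where the token is NOT that eos id;
    # prod over the eos axis -> 0 as soon as any eos id matches.
    unfinished_sequences = unfinished_sequences.mul(
        next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
    )
    print(unfinished_sequences)  # tensor([1, 0, 0])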
@@ -964,11 +824,7 @@ class RotaryEmbedding(torch.nn.Module):
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base
                ** (
                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
                    / self.dim
                )
                base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim)
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
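Aside, not part of the diff: the NTK-aware rescaling above amounts to inflating the rotary base before computing inverse frequencies. A standalone sketch with assumed values (dim=128, base=10000), not read from any config:

    import torch

    dim, base0 = 128, 10000.0  # assumed head dimension and rotary base

    def ntk_inv_freq(ntk_alpha: float) -> torch.Tensor:
        base = base0 * ntk_alpha ** (dim / (dim - 2))
        return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # ntk_alpha = 1.0 reproduces the unscaled frequencies; a larger alpha lowers the
    # higher-frequency components so the same rotations stretch over longer contexts.
    print(ntk_inv_freq(1.0)[:4])
    print(ntk_inv_freq(4.0)[:4])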
@@ -998,7 +854,6 @@ def _rotate_half(x):


def apply_rotary_pos_emb(t, freqs):

    rot_dim = freqs[0].shape[-1]
    cos, sin = freqs
    t_float = t.float()
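The hunk ends here; the rest of apply_rotary_pos_emb is not shown. For orientation only, a generic rotary application consistent with the cos/sin pair unpacked above (a sketch under that assumption, not the file's actual body):

    import torch

    def _rotate_half_demo(x):
        # Rotate the last dimension by swapping halves and negating the second half.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_demo(t, cos, sin):
        rot_dim = cos.shape[-1]
        t_float = t.float()
        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
        t_rot = (t_rot * cos) + (_rotate_half_demo(t_rot) * sin)
        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)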