Compare commits
2 Commits
4c0991a409
...
69cb525ab0
Author | SHA1 | Date |
---|---|---|
Colin | 69cb525ab0 | |
Colin | 94ecf0f561 |
|
@ -4,16 +4,13 @@
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import importlib
|
|
||||||
import math
|
import math
|
||||||
import inspect
|
import inspect
|
||||||
import pathlib
|
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
|
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
import warnings
|
|
||||||
|
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
|
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
|
||||||
|
@ -29,13 +26,8 @@ from transformers.modeling_outputs import (
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
try:
|
|
||||||
from einops import rearrange
|
|
||||||
except ImportError:
|
|
||||||
rearrange = None
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from einops import rearrange
|
||||||
SUPPORT_CUDA = torch.cuda.is_available()
|
|
||||||
|
|
||||||
from configuration_qwen import QWenConfig
|
from configuration_qwen import QWenConfig
|
||||||
from qwen_generation_utils import (
|
from qwen_generation_utils import (
|
||||||
|
@ -48,8 +40,6 @@ from qwen_generation_utils import (
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_SENTINEL = object()
|
|
||||||
|
|
||||||
|
|
||||||
class QWenAttention(nn.Module):
|
class QWenAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
@ -71,11 +61,9 @@ class QWenAttention(nn.Module):
|
||||||
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
||||||
|
|
||||||
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
|
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
|
||||||
|
|
||||||
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
|
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
|
||||||
|
|
||||||
self.use_dynamic_ntk = config.use_dynamic_ntk
|
self.use_dynamic_ntk = config.use_dynamic_ntk
|
||||||
self.use_logn_attn = config.use_logn_attn
|
|
||||||
|
|
||||||
logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
|
logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
|
||||||
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
|
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
|
||||||
|
@ -104,38 +92,20 @@ class QWenAttention(nn.Module):
|
||||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
):
|
):
|
||||||
mixed_x_layer = self.c_attn(hidden_states)
|
mixed_x_layer = self.c_attn(hidden_states)
|
||||||
|
|
||||||
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
||||||
|
|
||||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
if rotary_pos_emb_list is not None:
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
cur_len = query.shape[1]
|
rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
|
||||||
if len(rotary_pos_emb_list) == 1:
|
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
q_pos_emb, k_pos_emb = rotary_pos_emb
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
# Slice the pos emb for current inference
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||||
# Slice the pos emb for current inference
|
|
||||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
|
||||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
|
||||||
else:
|
|
||||||
query_list = []
|
|
||||||
key_list = []
|
|
||||||
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
|
|
||||||
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
|
||||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
|
||||||
# Slice the pos emb for current inference
|
|
||||||
query_list += [apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
|
|
||||||
key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
|
|
||||||
query = torch.cat(query_list, dim=0)
|
|
||||||
key = torch.cat(key_list, dim=0)
|
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past[0], layer_past[1]
|
past_key, past_value = layer_past[0], layer_past[1]
|
||||||
|
@ -143,13 +113,10 @@ class QWenAttention(nn.Module):
|
||||||
key = torch.cat((past_key, key), dim=1)
|
key = torch.cat((past_key, key), dim=1)
|
||||||
value = torch.cat((past_value, value), dim=1)
|
value = torch.cat((past_value, value), dim=1)
|
||||||
|
|
||||||
if use_cache:
|
present = (key, value)
|
||||||
present = (key, value)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
key_size = key.size(1)
|
key_size = key.size(1)
|
||||||
if key_size > self.seq_length and self.use_logn_attn and not self.training:
|
if key_size > self.seq_length and not self.training:
|
||||||
seq_start = key.size(1) - query.size(1)
|
seq_start = key.size(1) - query.size(1)
|
||||||
seq_end = key.size(1)
|
seq_end = key.size(1)
|
||||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||||
|
@ -172,23 +139,20 @@ class QWenAttention(nn.Module):
|
||||||
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
||||||
else:
|
else:
|
||||||
attention_mask = causal_mask
|
attention_mask = causal_mask
|
||||||
|
|
||||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
|
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
|
||||||
|
|
||||||
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
attn_output = self.c_proj(context_layer)
|
attn_output = self.c_proj(context_layer)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, present)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class QWenMLP(nn.Module):
|
class QWenMLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
|
|
||||||
self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
|
|
||||||
ff_dim_in = config.intermediate_size // 2
|
ff_dim_in = config.intermediate_size // 2
|
||||||
|
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
|
||||||
|
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
|
||||||
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
|
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
|
@ -213,7 +177,6 @@ class QWenBlock(nn.Module):
|
||||||
hidden_size,
|
hidden_size,
|
||||||
eps=config.layer_norm_epsilon,
|
eps=config.layer_norm_epsilon,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mlp = QWenMLP(config)
|
self.mlp = QWenMLP(config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -222,11 +185,6 @@ class QWenBlock(nn.Module):
|
||||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
):
|
):
|
||||||
layernorm_output = self.ln_1(hidden_states)
|
layernorm_output = self.ln_1(hidden_states)
|
||||||
|
|
||||||
|
@ -235,27 +193,17 @@ class QWenBlock(nn.Module):
|
||||||
rotary_pos_emb_list,
|
rotary_pos_emb_list,
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
|
|
||||||
outputs = attn_outputs[1:]
|
outputs = attn_outputs[1:]
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
layernorm_input = attn_output + residual
|
layernorm_input = attn_output + residual
|
||||||
|
|
||||||
layernorm_output = self.ln_2(layernorm_input)
|
layernorm_output = self.ln_2(layernorm_input)
|
||||||
|
|
||||||
residual = layernorm_input
|
residual = layernorm_input
|
||||||
mlp_output = self.mlp(layernorm_output)
|
mlp_output = self.mlp(layernorm_output)
|
||||||
hidden_states = residual + mlp_output
|
hidden_states = residual + mlp_output
|
||||||
|
outputs = (hidden_states,) + outputs
|
||||||
if use_cache:
|
|
||||||
outputs = (hidden_states,) + outputs
|
|
||||||
else:
|
|
||||||
outputs = (hidden_states,) + outputs[1:]
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@ -314,14 +262,7 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
):
|
):
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
|
@ -343,7 +284,6 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
attention_mask = attention_mask.to(dtype=self.dtype)
|
attention_mask = attention_mask.to(dtype=self.dtype)
|
||||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||||
|
|
||||||
encoder_attention_mask = None
|
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
|
@ -376,8 +316,7 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
presents = () if use_cache else None
|
presents = ()
|
||||||
all_self_attentions = () if output_attentions else None
|
|
||||||
all_hidden_states = None
|
all_hidden_states = None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
outputs = block(
|
outputs = block(
|
||||||
|
@ -385,27 +324,14 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
rotary_pos_emb_list=rotary_pos_emb_list,
|
rotary_pos_emb_list=rotary_pos_emb_list,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
presents = presents + (outputs[1],)
|
||||||
presents = presents + (outputs[1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
|
||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
hidden_states = hidden_states.view(output_shape)
|
hidden_states = hidden_states.view(output_shape)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
|
||||||
past_key_values=presents,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attentions,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -434,7 +360,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -447,11 +372,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -459,10 +380,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -498,19 +415,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
query: str,
|
query: str,
|
||||||
history: Optional[HistoryType],
|
history: Optional[HistoryType],
|
||||||
system: str = "You are a helpful assistant.",
|
system: str = "You are a helpful assistant.",
|
||||||
stream: Optional[bool] = _SENTINEL,
|
|
||||||
stop_words_ids: Optional[List[List[int]]] = None,
|
stop_words_ids: Optional[List[List[int]]] = None,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[str, HistoryType]:
|
) -> Tuple[str, HistoryType]:
|
||||||
generation_config = generation_config if generation_config is not None else self.generation_config
|
generation_config = generation_config if generation_config is not None else self.generation_config
|
||||||
|
|
||||||
assert stream is _SENTINEL
|
|
||||||
assert generation_config.chat_format == "chatml"
|
|
||||||
if history is None:
|
if history is None:
|
||||||
history = []
|
history = []
|
||||||
else:
|
else:
|
||||||
# make a copy of the user's input such that is is left untouched
|
|
||||||
history = copy.deepcopy(history)
|
history = copy.deepcopy(history)
|
||||||
|
|
||||||
if stop_words_ids is None:
|
if stop_words_ids is None:
|
||||||
|
@ -536,7 +449,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = decode_tokens(
|
response = decode_tokens(
|
||||||
outputs[0],
|
outputs[0],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -546,13 +458,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
verbose=False,
|
verbose=False,
|
||||||
errors="replace",
|
errors="replace",
|
||||||
)
|
)
|
||||||
|
|
||||||
# as history is a copy of the user inputs,
|
|
||||||
# we can always return the new turn to the user.
|
|
||||||
# separating input history and output history also enables the user
|
|
||||||
# to implement more complex history management
|
|
||||||
history.append((query, response))
|
history.append((query, response))
|
||||||
|
|
||||||
return response, history
|
return response, history
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
@ -562,7 +468,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
synced_gpus: Optional[bool] = None,
|
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -592,7 +497,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
synced_gpus=synced_gpus,
|
|
||||||
assistant_model=assistant_model,
|
assistant_model=assistant_model,
|
||||||
streamer=streamer,
|
streamer=streamer,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -605,7 +509,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
synced_gpus: Optional[bool] = None,
|
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
|
@ -637,21 +540,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
generation_config.pad_token_id = eos_token_id
|
generation_config.pad_token_id = eos_token_id
|
||||||
|
|
||||||
# 3. Define model inputs
|
# 3. Define model inputs
|
||||||
# inputs_tensor has to be defined
|
|
||||||
# model_input_name is defined if model-specific keyword input is passed
|
|
||||||
# otherwise model_input_name is None
|
|
||||||
# all model-specific keyword inputs are removed from `model_kwargs`
|
|
||||||
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||||
inputs, generation_config.bos_token_id, model_kwargs
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
)
|
)
|
||||||
# 4. Define other model kwargs
|
# 4. Define other model kwargs
|
||||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
|
||||||
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
|
|
||||||
# generating the first new token or not, and we only want to use the embeddings for the first new token)
|
|
||||||
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
|
|
||||||
model_kwargs["use_cache"] = True
|
|
||||||
else:
|
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
|
||||||
|
|
||||||
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
|
@ -713,7 +605,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
eos_token_id=generation_config.eos_token_id,
|
||||||
output_scores=generation_config.output_scores,
|
output_scores=generation_config.output_scores,
|
||||||
synced_gpus=synced_gpus,
|
|
||||||
streamer=streamer,
|
streamer=streamer,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
@ -727,9 +618,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
synced_gpus: bool = False,
|
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
|
@ -744,9 +633,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
output_attentions = (
|
|
||||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
|
||||||
)
|
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = None
|
scores = None
|
||||||
|
@ -754,14 +640,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
# keep track of which sequences are already finished
|
# keep track of which sequences are already finished
|
||||||
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False
|
||||||
# auto-regressive generation
|
# auto-regressive generation
|
||||||
while True:
|
while True:
|
||||||
# prepare model inputs
|
# prepare model inputs
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||||
|
|
||||||
# forward pass to get next token
|
# forward pass to get next token
|
||||||
outputs = self(**model_inputs, output_attentions=output_attentions)
|
outputs = self(**model_inputs)
|
||||||
|
|
||||||
next_token_logits = outputs.logits[:, -1, :]
|
next_token_logits = outputs.logits[:, -1, :]
|
||||||
|
|
||||||
|
@ -832,8 +718,6 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||||
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
emb = rearrange(emb, "n d -> 1 n 1 d")
|
emb = rearrange(emb, "n d -> 1 n 1 d")
|
||||||
|
|
||||||
cos, sin = emb.cos(), emb.sin()
|
cos, sin = emb.cos(), emb.sin()
|
||||||
|
@ -846,8 +730,6 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def _rotate_half(x):
|
def _rotate_half(x):
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||||
x1, x2 = x.unbind(dim=-2)
|
x1, x2 = x.unbind(dim=-2)
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
Loading…
Reference in New Issue