Re format qwen.
parent aa2d3b96c4
commit 4c0991a409
@@ -50,6 +50,7 @@ logger = logging.get_logger(__name__)

 _SENTINEL = object()

+
 class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -67,33 +68,22 @@ class QWenAttention(nn.Module):
         self.projection_size = config.kv_channels * config.num_attention_heads

         assert self.projection_size % config.num_attention_heads == 0
-        self.hidden_size_per_attention_head = (
-            self.projection_size // config.num_attention_heads
-        )
+        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads

         self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)

-        self.c_proj = nn.Linear(
-            config.hidden_size, self.projection_size, bias=not config.no_bias
-        )
+        self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)

         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn

-        logn_list = [
-            math.log(i, self.seq_length) if i > self.seq_length else 1
-            for i in range(1, 32768)
-        ]
+        logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
         logn_tensor = torch.tensor(logn_list)[None, :, None, None]
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)

         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-        self.softmax_in_fp32 = (
-            config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
-        )
-        self.use_cache_kernel = (
-            config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
-        )
+        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
+        self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
         cache_dtype = torch.float
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
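The logn_tensor buffer built above is only registered in this hunk; the site that consumes it sits outside the diff. As a rough sketch, assuming the (batch, seq, heads, head_dim) layout used elsewhere in this file, the buffer is sliced to the current window and multiplied into the queries; the helper name below is illustrative, not part of the commit:

import math

import torch

def logn_scale(query: torch.Tensor, key: torch.Tensor, logn_tensor: torch.Tensor) -> torch.Tensor:
    """Scale queries that sit beyond the training length.

    logn_tensor[0, i, 0, 0] holds log_{seq_length}(i + 1), clamped to 1 inside the
    training window, matching the buffer registered in the hunk above.
    Tensors are assumed to be laid out as (batch, seq, heads, head_dim).
    """
    seq_start = key.size(1) - query.size(1)  # offset of the current chunk in the full sequence
    seq_end = key.size(1)
    scale = logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
    return query * scale.expand_as(query)

# Building the buffer exactly as in the diff, for a toy training length of 2048:
seq_length = 2048
logn_list = [math.log(i, seq_length) if i > seq_length else 1 for i in range(1, 32768)]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]

q = torch.randn(1, 4096, 32, 128)  # 4096 tokens, twice the toy training window
k = torch.randn(1, 4096, 32, 128)
scaled_q = logn_scale(q, k, logn_tensor)
print(scaled_q.shape)  # torch.Size([1, 4096, 32, 128]); positions past 2048 are scaled up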
@@ -114,10 +104,6 @@ class QWenAttention(nn.Module):
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
     ):
         mixed_x_layer = self.c_attn(hidden_states)
@@ -146,9 +132,7 @@ class QWenAttention(nn.Module):
                     rotary_pos_emb = (rotary_pos_emb,) * 2
                     q_pos_emb, k_pos_emb = rotary_pos_emb
                     # Slice the pos emb for current inference
-                    query_list += [
-                        apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
-                    ]
+                    query_list += [apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
                     key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
                 query = torch.cat(query_list, dim=0)
                 key = torch.cat(key_list, dim=0)
@@ -173,9 +157,9 @@ class QWenAttention(nn.Module):

         key_size = key.size(1)
         if query.size(1) == key_size:
-            causal_mask = torch.tril(
-                torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
-            ).view(1, 1, key_size, key_size)
+            causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
+                1, 1, key_size, key_size
+            )
         else:
             causal_mask = None
         query = query.permute(0, 2, 1, 3)
@@ -185,24 +169,16 @@ class QWenAttention(nn.Module):
         if attention_mask is not None:
             attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
             if causal_mask is not None:
-                attention_mask = attention_mask.masked_fill(
-                    ~causal_mask, torch.finfo(query.dtype).min
-                )
+                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
         else:
             attention_mask = causal_mask
-        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask
-        ).transpose(1, 2)
+        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)

         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

         attn_output = self.c_proj(context_layer)

         outputs = (attn_output, present)
-        if output_attentions:
-            raise ValueError(
-                "Cannot output attentions while using scaled_dot_product_attention"
-            )

         return outputs

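The joined lines above combine an expanded padding mask with a boolean causal mask and hand the result to F.scaled_dot_product_attention. A self-contained toy example of the same construction, with illustrative shapes, to make the broadcasting explicit:

import torch
import torch.nn.functional as F

# Toy shapes: batch=2, heads=4, q_len=k_len=8, head_dim=16 (already in (batch, heads, seq, dim) layout).
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

key_size = k.size(-2)
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=q.device)).view(
    1, 1, key_size, key_size
)

# An additive padding mask (0 where attended), expanded and then merged with the
# causal mask exactly as the hunk above does.
attention_mask = torch.zeros(2, 1, 1, key_size)
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(q.dtype).min)

attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask).transpose(1, 2)
print(attn_output.shape)  # torch.Size([2, 8, 4, 16])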
@@ -210,12 +186,8 @@ class QWenAttention(nn.Module):
 class QWenMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.w1 = nn.Linear(
-            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
-        )
-        self.w2 = nn.Linear(
-            config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
-        )
+        self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
+        self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
         ff_dim_in = config.intermediate_size // 2
         self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)

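The forward pass of QWenMLP is not part of this hunk; the two half-width projections are normally combined as a SwiGLU-style gate before c_proj. A sketch under that assumption (the class name and constructor signature here are illustrative):

import torch
import torch.nn.functional as F
from torch import nn

class GatedMLP(nn.Module):
    """Sketch of a SwiGLU-style feed-forward block with the same layer layout as QWenMLP."""

    def __init__(self, hidden_size: int, intermediate_size: int, no_bias: bool = True):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size // 2, bias=not no_bias)
        self.w2 = nn.Linear(hidden_size, intermediate_size // 2, bias=not no_bias)
        self.c_proj = nn.Linear(intermediate_size // 2, hidden_size, bias=not no_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        a1 = self.w1(hidden_states)           # gate branch
        a2 = self.w2(hidden_states)           # value branch
        return self.c_proj(a1 * F.silu(a2))   # gated activation, then projection back to hidden_size

print(GatedMLP(64, 256)(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])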
@@ -263,10 +235,9 @@ class QWenBlock(nn.Module):
             rotary_pos_emb_list,
             layer_past=layer_past,
             attention_mask=attention_mask,
-            head_mask=head_mask,
             use_cache=use_cache,
-            output_attentions=output_attentions,
         )

         attn_output = attn_outputs[0]

         outputs = attn_outputs[1:]
@@ -299,8 +270,8 @@ class QWenPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

-class QWenModel(QWenPreTrainedModel):

+class QWenModel(QWenPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.vocab_size = config.vocab_size
@@ -322,9 +293,7 @@ class QWenModel(QWenPreTrainedModel):
         dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
         self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

-        self.h = nn.ModuleList(
-            [QWenBlock(config) for i in range(config.num_hidden_layers)]
-        )
+        self.h = nn.ModuleList([QWenBlock(config) for i in range(config.num_hidden_layers)])
         self.ln_f = RMSNorm(
             self.embed_dim,
             eps=config.layer_norm_epsilon,
@@ -348,19 +317,13 @@ class QWenModel(QWenPreTrainedModel):
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None
+        output_attentions: Optional[bool] = None,
     ):
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         use_cache = use_cache if use_cache is not None else self.config.use_cache

         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time"
-            )
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
@@ -399,12 +362,7 @@ class QWenModel(QWenPreTrainedModel):
         else:
             ntk_alpha_list = []
             if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = (
-                    attention_mask.squeeze(1)
-                    .squeeze(1)
-                    .eq(0)
-                    .sum(dim=-1, dtype=torch.int32)
-                )
+                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
                 for i in range(hidden_states.size()[0]):
                     true_seq_len = true_seq_lens[i].item()
                     ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -413,10 +371,7 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha = self.get_ntk_alpha(kv_seq_len)
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
-        rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
-            for ntk_alpha in ntk_alpha_list
-        ]
+        rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]

         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
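get_ntk_alpha is referenced here but defined outside the diff. A sketch of the scaling rule the Qwen modeling code commonly uses, written as a free function with the training length passed in; treat the exact formula as an assumption:

import math

def get_ntk_alpha(true_seq_len: int, seq_length: int) -> float:
    """Pick an NTK scaling factor once the sequence outgrows the training length."""
    context_value = math.log(true_seq_len / seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)

# With a 2048-token training window:
print(get_ntk_alpha(2048, 2048))  # 1 (no rescaling inside the window)
print(get_ntk_alpha(8192, 2048))  # 7 (base is stretched for 4x the window)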
@@ -425,7 +380,6 @@ class QWenModel(QWenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             outputs = block(
                 hidden_states,
                 layer_past=layer_past,
@@ -443,9 +397,7 @@ class QWenModel(QWenPreTrainedModel):
                 presents = presents + (outputs[1],)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (
-                    outputs[2 if use_cache else 1],
-                )
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -458,7 +410,6 @@ class QWenModel(QWenPreTrainedModel):


 class QWenLMHeadModel(QWenPreTrainedModel):
-
     def __init__(self, config):
         super().__init__(config)

|
@ -466,9 +417,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
||||||
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
|
||||||
):
|
|
||||||
if past_key_values:
|
if past_key_values:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
@ -502,10 +451,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None
|
output_attentions: Optional[bool] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
|
@ -515,7 +462,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -527,9 +474,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
||||||
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
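The joined loss line implements the usual shift-by-one alignment for causal language modeling: the logit at position t predicts the label at position t+1. A toy example with illustrative shapes:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
lm_logits = torch.randn(2, 5, vocab_size)           # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 5))       # (batch, seq)

shift_logits = lm_logits[..., :-1, :].contiguous()  # drop the last position's prediction
shift_labels = labels[..., 1:].contiguous()         # drop the first token (nothing predicts it)

loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.item())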
|
@ -558,11 +503,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[str, HistoryType]:
|
) -> Tuple[str, HistoryType]:
|
||||||
generation_config = (
|
generation_config = generation_config if generation_config is not None else self.generation_config
|
||||||
generation_config
|
|
||||||
if generation_config is not None
|
|
||||||
else self.generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert stream is _SENTINEL
|
assert stream is _SENTINEL
|
||||||
assert generation_config.chat_format == "chatml"
|
assert generation_config.chat_format == "chatml"
|
||||||
|
@ -587,9 +528,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
chat_format=generation_config.chat_format,
|
chat_format=generation_config.chat_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
stop_words_ids.extend(
|
stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))
|
||||||
get_stop_words_ids(generation_config.chat_format, tokenizer)
|
|
||||||
)
|
|
||||||
input_ids = torch.tensor([context_tokens]).to(self.device)
|
input_ids = torch.tensor([context_tokens]).to(self.device)
|
||||||
outputs = self.generate(
|
outputs = self.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -622,19 +561,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
Callable[[int, torch.Tensor], List[int]]
|
|
||||||
] = None,
|
|
||||||
synced_gpus: Optional[bool] = None,
|
synced_gpus: Optional[bool] = None,
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
generation_config = (
|
generation_config = generation_config if generation_config is not None else self.generation_config
|
||||||
generation_config
|
|
||||||
if generation_config is not None
|
|
||||||
else self.generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process stop_words_ids.
|
# Process stop_words_ids.
|
||||||
stop_words_ids = kwargs.pop("stop_words_ids", None)
|
stop_words_ids = kwargs.pop("stop_words_ids", None)
|
||||||
|
@ -671,9 +604,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
Callable[[int, torch.Tensor], List[int]]
|
|
||||||
] = None,
|
|
||||||
synced_gpus: Optional[bool] = None,
|
synced_gpus: Optional[bool] = None,
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
|
@ -685,26 +616,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
self._validate_model_class()
|
self._validate_model_class()
|
||||||
|
|
||||||
generation_config = copy.deepcopy(generation_config)
|
generation_config = copy.deepcopy(generation_config)
|
||||||
model_kwargs = generation_config.update(
|
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||||
**kwargs
|
|
||||||
) # All unused kwargs must be model kwargs
|
|
||||||
generation_config.validate()
|
generation_config.validate()
|
||||||
self._validate_model_kwargs(model_kwargs.copy())
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
# 2. Set generation parameters if not already defined
|
# 2. Set generation parameters if not already defined
|
||||||
logits_processor = (
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
logits_processor if logits_processor is not None else LogitsProcessorList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
)
|
|
||||||
stopping_criteria = (
|
|
||||||
stopping_criteria
|
|
||||||
if stopping_criteria is not None
|
|
||||||
else StoppingCriteriaList()
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||||
generation_config.pad_token_id is None
|
|
||||||
and generation_config.eos_token_id is not None
|
|
||||||
):
|
|
||||||
if model_kwargs.get("attention_mask", None) is None:
|
if model_kwargs.get("attention_mask", None) is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
|
@ -713,9 +633,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
eos_token_id = generation_config.eos_token_id
|
eos_token_id = generation_config.eos_token_id
|
||||||
if isinstance(eos_token_id, list):
|
if isinstance(eos_token_id, list):
|
||||||
eos_token_id = eos_token_id[0]
|
eos_token_id = eos_token_id[0]
|
||||||
logger.warning(
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
|
|
||||||
)
|
|
||||||
generation_config.pad_token_id = eos_token_id
|
generation_config.pad_token_id = eos_token_id
|
||||||
|
|
||||||
# 3. Define model inputs
|
# 3. Define model inputs
|
||||||
|
@ -735,46 +653,27 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
else:
|
else:
|
||||||
model_kwargs["use_cache"] = generation_config.use_cache
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
|
|
||||||
accepts_attention_mask = "attention_mask" in set(
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
inspect.signature(self.forward).parameters.keys()
|
|
||||||
)
|
|
||||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
|
|
||||||
if (
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||||
model_kwargs.get("attention_mask", None) is None
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
and requires_attention_mask
|
|
||||||
and accepts_attention_mask
|
|
||||||
):
|
|
||||||
model_kwargs[
|
|
||||||
"attention_mask"
|
|
||||||
] = self._prepare_attention_mask_for_generation(
|
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
generation_config.pad_token_id,
|
generation_config.pad_token_id,
|
||||||
generation_config.eos_token_id,
|
generation_config.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||||
input_ids = (
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||||
inputs_tensor
|
|
||||||
if model_input_name == "input_ids"
|
|
||||||
else model_kwargs.pop("input_ids")
|
|
||||||
)
|
|
||||||
|
|
||||||
if streamer is not None:
|
if streamer is not None:
|
||||||
streamer.put(input_ids.cpu())
|
streamer.put(input_ids.cpu())
|
||||||
|
|
||||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||||
input_ids_length = input_ids.shape[-1]
|
input_ids_length = input_ids.shape[-1]
|
||||||
has_default_max_length = (
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
kwargs.get("max_length") is None
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||||
and generation_config.max_length is not None
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
)
|
|
||||||
generation_config.max_length = (
|
|
||||||
generation_config.max_new_tokens + input_ids_length
|
|
||||||
)
|
|
||||||
self._validate_generated_length(
|
|
||||||
generation_config, input_ids_length, has_default_max_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# 8. prepare distribution pre_processing samplers
|
# 8. prepare distribution pre_processing samplers
|
||||||
logits_processor = self._get_logits_processor(
|
logits_processor = self._get_logits_processor(
|
||||||
|
@ -835,53 +734,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
# init values
|
# init values
|
||||||
logits_processor = (
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
logits_processor if logits_processor is not None else LogitsProcessorList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
)
|
|
||||||
stopping_criteria = (
|
|
||||||
stopping_criteria
|
|
||||||
if stopping_criteria is not None
|
|
||||||
else StoppingCriteriaList()
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_warper = (
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
logits_warper if logits_warper is not None else LogitsProcessorList()
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
)
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
pad_token_id = (
|
|
||||||
pad_token_id
|
|
||||||
if pad_token_id is not None
|
|
||||||
else self.generation_config.pad_token_id
|
|
||||||
)
|
|
||||||
eos_token_id = (
|
|
||||||
eos_token_id
|
|
||||||
if eos_token_id is not None
|
|
||||||
else self.generation_config.eos_token_id
|
|
||||||
)
|
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = (
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
torch.tensor(eos_token_id).to(input_ids.device)
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
if eos_token_id is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
output_scores = (
|
|
||||||
output_scores
|
|
||||||
if output_scores is not None
|
|
||||||
else self.generation_config.output_scores
|
|
||||||
)
|
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
output_attentions
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
if output_attentions is not None
|
|
||||||
else self.generation_config.output_attentions
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = None
|
scores = None
|
||||||
|
|
||||||
# keep track of which sequences are already finished
|
# keep track of which sequences are already finished
|
||||||
unfinished_sequences = torch.ones(
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
|
||||||
)
|
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
# auto-regressive generation
|
# auto-regressive generation
|
||||||
|
@@ -890,10 +761,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                output_attentions=output_attentions
-            )
+            outputs = self(**model_inputs, output_attentions=output_attentions)

             next_token_logits = outputs.logits[:, -1, :]

@@ -908,27 +776,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
-                    raise ValueError(
-                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
-                    )
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                    1 - unfinished_sequences
-                )
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=False
-            )
+            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)

             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
-                    .ne(eos_token_id_tensor.unsqueeze(1))
-                    .prod(dim=0)
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                 )

             # stop when each sentence is finished
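The tile/ne/prod chain joined above is the standard trick for tracking which sequences have hit any of several EOS ids. A small worked example with made-up token ids:

import torch

next_tokens = torch.tensor([5, 2, 7])                   # one sampled token per sequence
eos_token_id_tensor = torch.tensor([2, 7])              # two possible end-of-sequence ids
unfinished_sequences = torch.ones(3, dtype=torch.long)

# tile -> (num_eos, batch); ne -> True where the token differs from that eos id;
# prod over the eos axis -> 1 only if the token matched none of the eos ids.
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
print(unfinished_sequences)  # tensor([1, 0, 0])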
@@ -964,11 +824,7 @@ class RotaryEmbedding(torch.nn.Module):
         if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
             base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
             self.inv_freq = 1.0 / (
-                base
-                ** (
-                    torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
-                    / self.dim
-                )
+                base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim)
             )
             self._seq_len_cached = max(2 * seqlen, 16)
             self._ntk_alpha_cached = ntk_alpha
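The joined expression rebuilds the rotary inverse frequencies from an NTK-rescaled base. A standalone sketch of the same computation; the function name is illustrative:

import torch

def ntk_scaled_inv_freq(dim: int, base: float, ntk_alpha: float) -> torch.Tensor:
    """Inverse frequencies after NTK-aware base rescaling, mirroring the joined line above."""
    scaled_base = base * ntk_alpha ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))

# ntk_alpha = 1.0 reproduces the vanilla RoPE frequencies; larger values stretch the
# longest wavelengths so positions beyond the training window stay distinguishable.
print(ntk_scaled_inv_freq(dim=128, base=10000.0, ntk_alpha=1.0)[:4])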
@@ -998,7 +854,6 @@ def _rotate_half(x):


 def apply_rotary_pos_emb(t, freqs):
-
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()