Remove attention_mask

parent cd50c10e8c
commit 0458e7303c
@@ -96,7 +96,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
     ):
         mixed_x_layer = self.c_attn(hidden_states)
         query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@@ -120,32 +119,21 @@ class QWenAttention(nn.Module):
             query = query * logn_tensor.expand_as(query)

         key_size = key.size(1)
-        if query.size(1) == key_size:
-            causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
-                1, 1, key_size, key_size
-            )
-        else:
-            causal_mask = None
+        causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
+            1, 1, key_size, key_size
+        )
         query = query.permute(0, 2, 1, 3)
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)

-        if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
-            if causal_mask is not None:
-                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
-            else:
-                attention_mask = causal_mask
-
         # qk = query @ key.transpose(-2, -1)
         # qk = qk[0]
-        # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
-        # print("layer:" + str(self.index) + " query.shape:"+ str(query.shape))
-        # print("layer:" + str(self.index) + " key.shape:"+ str(key.shape))
-        # print("layer:" + str(self.index) + " value.shape:"+ str(value.shape))
-        # print("\n")
+        # prePath = "../generated/query_matmul_key/img/"
+        # show.DumpTensorToImage(
+        #     qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
+        # )

-        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
+        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(context_layer)

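The simplified attention path above relies on `F.scaled_dot_product_attention` treating a boolean `attn_mask` as "True = may attend". A minimal standalone sketch of that behaviour, with made-up shapes for illustration (not code from this repository):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes, for illustration only.
batch, heads, seq, head_dim = 1, 2, 5, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Same construction as the patched code: lower-triangular boolean mask,
# True where a position may attend (itself and earlier positions).
causal_mask = torch.tril(torch.ones((seq, seq), dtype=torch.bool)).view(1, 1, seq, seq)

out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# With equal query/key lengths the two should agree up to kernel-dependent float error.
print(torch.allclose(out_masked, out_causal, atol=1e-5))
```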
@@ -189,15 +177,10 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
     ):
         layernorm_output = self.ln_1(hidden_states)

-        attn_outputs = self.attn(
-            layernorm_output,
-            rotary_pos_emb_list,
-            attention_mask=attention_mask,
-        )
+        attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list)
         attn_output = attn_outputs[0]
         residual = hidden_states
         layernorm_input = attn_output + residual
@@ -259,7 +242,6 @@ class QWenModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
@@ -275,12 +257,6 @@ class QWenModel(QWenPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if attention_mask is not None:
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         if inputs_embeds is None:
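The block removed here is the usual conversion of a 0/1 padding mask into an additive float bias. A small sketch of what that removed code computed, with an illustrative batch and dtype (not code from this repository):

```python
import torch

# Illustrative 0/1 padding mask: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
dtype = torch.float16

bias = attention_mask.view(attention_mask.size(0), -1)  # (batch, seq)
bias = bias[:, None, None, :]                           # (batch, 1, 1, seq)
bias = bias.to(dtype=dtype)
bias = (1.0 - bias) * torch.finfo(dtype).min            # ~0 where attended, finfo.min where padded

print(bias)  # positions that may be attended stay near 0, padded positions get a large negative bias
```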
@@ -295,15 +271,8 @@ class QWenModel(QWenPreTrainedModel):
             ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
         else:
             ntk_alpha_list = []
-            if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
-                for i in range(hidden_states.size()[0]):
-                    true_seq_len = true_seq_lens[i].item()
-                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
-                    ntk_alpha_list.append(ntk_alpha)
-            else:
-                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
-                ntk_alpha_list.append(ntk_alpha)
+            ntk_alpha = self.get_ntk_alpha(kv_seq_len)
+            ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]

@@ -312,11 +281,7 @@ class QWenModel(QWenPreTrainedModel):

         all_hidden_states = None
         for block in self.h:
-            hidden_states = block(
-                hidden_states,
-                rotary_pos_emb_list=rotary_pos_emb_list,
-                attention_mask=attention_mask,
-            )
+            hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)

         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -332,31 +297,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self.post_init()

     def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
-        if input_ids.size(0) == 1:
-            attention_mask = None
-        else:
-            attention_mask = kwargs.get("attention_mask", None)
-
         model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "attention_mask": attention_mask,
-            }
-        )
         return model_inputs

     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         transformer_outputs = self.transformer(
             input_ids,
-            attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
         )
@@ -418,6 +370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
+            tokenizer=tokenizer,
             **kwargs,
         )
         decoded, response, end_reason = decode_tokens(
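Since `generate` now receives the tokenizer explicitly, a caller-side sketch might look like the following. The checkpoint path is a placeholder, and the `chat` signature is assumed to follow the usual Qwen convention; adjust to this repository's actual entry point:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path; substitute the actual checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("./qwen", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("./qwen", trust_remote_code=True).eval()

# chat() forwards the tokenizer down to generate(), where it can be used for
# per-step debugging (see the commented decode_tokens block further below).
response, history = model.chat(tokenizer, "Hello", history=None)
print(response)
```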
@@ -434,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self,
         inputs: Optional[torch.Tensor] = None,
         stop_words_ids=[],
+        tokenizer=None,
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
@@ -461,16 +415,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         )
         # 4. Define other model kwargs

-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
-        requires_attention_mask = "encoder_outputs" not in model_kwargs
-
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
-            )
-
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

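For context, the removed step called a transformers helper that, roughly speaking, derives a 0/1 mask from the pad token. A hedged approximation of what it produced, not the library's exact implementation:

```python
import torch

def approx_attention_mask(input_ids: torch.LongTensor, pad_token_id: int) -> torch.LongTensor:
    # Rough stand-in for what _prepare_attention_mask_for_generation yields:
    # 1 for real tokens, 0 for padding.
    return input_ids.ne(pad_token_id).long()

# Illustrative ids, with pad_token_id assumed to be 0.
ids = torch.tensor([[11, 12, 13, 0, 0]])
print(approx_attention_mask(ids, pad_token_id=0))  # tensor([[1, 1, 1, 0, 0]])
```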
@@ -551,6 +495,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

+               # decoded, response, end_reason = decode_tokens(
+               #     next_tokens,
+               #     tokenizer,
+               #     raw_text_len=0,
+               #     context_length=0,
+               #     errors="replace",
+               # )
+               # print(decoded)
+
                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True
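The commented-out block above is a per-step debugging hook: when enabled, it would decode each batch of sampled token ids with the tokenizer that is now threaded into `generate`. A rough standalone illustration of the idea, using `tokenizer.decode` instead of the repository's `decode_tokens` helper (the helper name and behaviour are only what the diff shows):

```python
import torch

def debug_print_step(next_tokens: torch.Tensor, tokenizer) -> None:
    # Decode only the freshly sampled ids; decode_tokens in the repo additionally
    # handles raw_text_len/context_length bookkeeping and errors="replace".
    text = tokenizer.decode(next_tokens.tolist())
    print(text)
```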