Remove attention_mask

This commit is contained in:
Colin 2024-01-20 18:08:20 +08:00
parent cd50c10e8c
commit 0458e7303c
1 changed files with 23 additions and 70 deletions

View File

@ -96,7 +96,6 @@ class QWenAttention(nn.Module):
self, self,
hidden_states: Optional[Tuple[torch.FloatTensor]], hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
): ):
mixed_x_layer = self.c_attn(hidden_states) mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2) query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@ -120,32 +119,21 @@ class QWenAttention(nn.Module):
query = query * logn_tensor.expand_as(query) query = query * logn_tensor.expand_as(query)
key_size = key.size(1) key_size = key.size(1)
if query.size(1) == key_size: causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view( 1, 1, key_size, key_size
1, 1, key_size, key_size )
)
else:
causal_mask = None
query = query.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3)
if attention_mask is not None:
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
if causal_mask is not None:
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else:
attention_mask = causal_mask
# qk = query @ key.transpose(-2, -1) # qk = query @ key.transpose(-2, -1)
# qk = qk[0] # qk = qk[0]
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png") # prePath = "../generated/query_matmul_key/img/"
# print("layer:" + str(self.index) + " query.shape:"+ str(query.shape)) # show.DumpTensorToImage(
# print("layer:" + str(self.index) + " key.shape:"+ str(key.shape)) # qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
# print("layer:" + str(self.index) + " value.shape:"+ str(value.shape)) # )
# print("\n")
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2) attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer) attn_output = self.c_proj(context_layer)
@ -189,15 +177,10 @@ class QWenBlock(nn.Module):
self, self,
hidden_states: Optional[Tuple[torch.FloatTensor]], hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
): ):
layernorm_output = self.ln_1(hidden_states) layernorm_output = self.ln_1(hidden_states)
attn_outputs = self.attn( attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list)
layernorm_output,
rotary_pos_emb_list,
attention_mask=attention_mask,
)
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
residual = hidden_states residual = hidden_states
layernorm_input = attn_output + residual layernorm_input = attn_output + residual
@ -259,7 +242,6 @@ class QWenModel(QWenPreTrainedModel):
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
): ):
@ -275,12 +257,6 @@ class QWenModel(QWenPreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype)
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None: if inputs_embeds is None:
@ -295,15 +271,8 @@ class QWenModel(QWenPreTrainedModel):
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
else: else:
ntk_alpha_list = [] ntk_alpha_list = []
if attention_mask is not None and kv_seq_len > self.seq_length: ntk_alpha = self.get_ntk_alpha(kv_seq_len)
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) ntk_alpha_list.append(ntk_alpha)
for i in range(hidden_states.size()[0]):
true_seq_len = true_seq_lens[i].item()
ntk_alpha = self.get_ntk_alpha(true_seq_len)
ntk_alpha_list.append(ntk_alpha)
else:
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
ntk_alpha_list.append(ntk_alpha)
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list] rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
@ -312,11 +281,7 @@ class QWenModel(QWenPreTrainedModel):
all_hidden_states = None all_hidden_states = None
for block in self.h: for block in self.h:
hidden_states = block( hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
hidden_states,
rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask,
)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape) hidden_states = hidden_states.view(output_shape)
@ -332,31 +297,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self.post_init() self.post_init()
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
if input_ids.size(0) == 1:
attention_mask = None
else:
attention_mask = kwargs.get("attention_mask", None)
model_inputs = {"input_ids": input_ids} model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
}
)
return model_inputs return model_inputs
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids, input_ids,
attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
@ -418,6 +370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
outputs = self.generate( outputs = self.generate(
input_ids, input_ids,
stop_words_ids=stop_words_ids, stop_words_ids=stop_words_ids,
tokenizer=tokenizer,
**kwargs, **kwargs,
) )
decoded, response, end_reason = decode_tokens( decoded, response, end_reason = decode_tokens(
@ -434,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self, self,
inputs: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None,
stop_words_ids=[], stop_words_ids=[],
tokenizer=None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs, **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]: ) -> Union[GenerateOutput, torch.LongTensor]:
@ -461,16 +415,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
) )
# 4. Define other model kwargs # 4. Define other model kwargs
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation # 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
@ -551,6 +495,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
) )
# decoded, response, end_reason = decode_tokens(
# next_tokens,
# tokenizer,
# raw_text_len=0,
# context_length=0,
# errors="replace",
# )
# print(decoded)
# stop when each sentence is finished # stop when each sentence is finished
if unfinished_sequences.max() == 0: if unfinished_sequences.max() == 0:
this_peer_finished = True this_peer_finished = True