Remove attention_mask
This commit is contained in:
parent
cd50c10e8c
commit
0458e7303c
|
@ -96,7 +96,6 @@ class QWenAttention(nn.Module):
|
|||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
mixed_x_layer = self.c_attn(hidden_states)
|
||||
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
||||
|
@ -120,32 +119,21 @@ class QWenAttention(nn.Module):
|
|||
query = query * logn_tensor.expand_as(query)
|
||||
|
||||
key_size = key.size(1)
|
||||
if query.size(1) == key_size:
|
||||
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
|
||||
1, 1, key_size, key_size
|
||||
)
|
||||
else:
|
||||
causal_mask = None
|
||||
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
|
||||
1, 1, key_size, key_size
|
||||
)
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
|
||||
if causal_mask is not None:
|
||||
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
||||
else:
|
||||
attention_mask = causal_mask
|
||||
|
||||
# qk = query @ key.transpose(-2, -1)
|
||||
# qk = qk[0]
|
||||
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
|
||||
# print("layer:" + str(self.index) + " query.shape:"+ str(query.shape))
|
||||
# print("layer:" + str(self.index) + " key.shape:"+ str(key.shape))
|
||||
# print("layer:" + str(self.index) + " value.shape:"+ str(value.shape))
|
||||
# print("\n")
|
||||
# prePath = "../generated/query_matmul_key/img/"
|
||||
# show.DumpTensorToImage(
|
||||
# qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
|
||||
# )
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
|
||||
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.c_proj(context_layer)
|
||||
|
||||
|
@ -189,15 +177,10 @@ class QWenBlock(nn.Module):
|
|||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
layernorm_output = self.ln_1(hidden_states)
|
||||
|
||||
attn_outputs = self.attn(
|
||||
layernorm_output,
|
||||
rotary_pos_emb_list,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list)
|
||||
attn_output = attn_outputs[0]
|
||||
residual = hidden_states
|
||||
layernorm_input = attn_output + residual
|
||||
|
@ -259,7 +242,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
|
@ -275,12 +257,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
attention_mask = attention_mask.to(dtype=self.dtype)
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
|
@ -295,15 +271,8 @@ class QWenModel(QWenPreTrainedModel):
|
|||
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
|
||||
else:
|
||||
ntk_alpha_list = []
|
||||
if attention_mask is not None and kv_seq_len > self.seq_length:
|
||||
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
|
||||
for i in range(hidden_states.size()[0]):
|
||||
true_seq_len = true_seq_lens[i].item()
|
||||
ntk_alpha = self.get_ntk_alpha(true_seq_len)
|
||||
ntk_alpha_list.append(ntk_alpha)
|
||||
else:
|
||||
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
||||
ntk_alpha_list.append(ntk_alpha)
|
||||
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
||||
ntk_alpha_list.append(ntk_alpha)
|
||||
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
||||
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
|
||||
|
||||
|
@ -312,11 +281,7 @@ class QWenModel(QWenPreTrainedModel):
|
|||
|
||||
all_hidden_states = None
|
||||
for block in self.h:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
rotary_pos_emb_list=rotary_pos_emb_list,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
@ -332,31 +297,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
self.post_init()
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
|
||||
if input_ids.size(0) == 1:
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
@ -418,6 +370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
outputs = self.generate(
|
||||
input_ids,
|
||||
stop_words_ids=stop_words_ids,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
decoded, response, end_reason = decode_tokens(
|
||||
|
@ -434,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
stop_words_ids=[],
|
||||
tokenizer=None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
|
@ -461,16 +415,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
)
|
||||
# 4. Define other model kwargs
|
||||
|
||||
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||
|
||||
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||
inputs_tensor,
|
||||
generation_config.pad_token_id,
|
||||
generation_config.eos_token_id,
|
||||
)
|
||||
|
||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||
|
||||
|
@ -551,6 +495,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
# decoded, response, end_reason = decode_tokens(
|
||||
# next_tokens,
|
||||
# tokenizer,
|
||||
# raw_text_len=0,
|
||||
# context_length=0,
|
||||
# errors="replace",
|
||||
# )
|
||||
# print(decoded)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
|
Loading…
Reference in New Issue