From 0458e7303cd185170488aeb52fd9fe43054200ec Mon Sep 17 00:00:00 2001 From: Colin Date: Sat, 20 Jan 2024 18:08:20 +0800 Subject: [PATCH] Remove attention_mask --- qwen/modeling_qwen.py | 93 +++++++++++-------------------------------- 1 file changed, 23 insertions(+), 70 deletions(-) diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index 6c75d38..b9bcc2b 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -96,7 +96,6 @@ class QWenAttention(nn.Module): self, hidden_states: Optional[Tuple[torch.FloatTensor]], rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, ): mixed_x_layer = self.c_attn(hidden_states) query, key, value = mixed_x_layer.split(self.split_size, dim=2) @@ -120,32 +119,21 @@ class QWenAttention(nn.Module): query = query * logn_tensor.expand_as(query) key_size = key.size(1) - if query.size(1) == key_size: - causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view( - 1, 1, key_size, key_size - ) - else: - causal_mask = None + causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view( + 1, 1, key_size, key_size + ) query = query.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3) - if attention_mask is not None: - attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1) - if causal_mask is not None: - attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) - else: - attention_mask = causal_mask - # qk = query @ key.transpose(-2, -1) # qk = qk[0] - # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png") - # print("layer:" + str(self.index) + " query.shape:"+ str(query.shape)) - # print("layer:" + str(self.index) + " key.shape:"+ str(key.shape)) - # print("layer:" + str(self.index) + " value.shape:"+ str(value.shape)) - # print("\n") + # prePath = "../generated/query_matmul_key/img/" + # show.DumpTensorToImage( + # qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png" + # ) - attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2) + attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2) context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(context_layer) @@ -189,15 +177,10 @@ class QWenBlock(nn.Module): self, hidden_states: Optional[Tuple[torch.FloatTensor]], rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, ): layernorm_output = self.ln_1(hidden_states) - attn_outputs = self.attn( - layernorm_output, - rotary_pos_emb_list, - attention_mask=attention_mask, - ) + attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list) attn_output = attn_outputs[0] residual = hidden_states layernorm_input = attn_output + residual @@ -259,7 +242,6 @@ class QWenModel(QWenPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): @@ -275,12 +257,6 @@ class QWenModel(QWenPreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if attention_mask is not None: - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) if inputs_embeds is None: @@ -295,15 +271,8 @@ class QWenModel(QWenPreTrainedModel): ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list else: ntk_alpha_list = [] - if attention_mask is not None and kv_seq_len > self.seq_length: - true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) - for i in range(hidden_states.size()[0]): - true_seq_len = true_seq_lens[i].item() - ntk_alpha = self.get_ntk_alpha(true_seq_len) - ntk_alpha_list.append(ntk_alpha) - else: - ntk_alpha = self.get_ntk_alpha(kv_seq_len) - ntk_alpha_list.append(ntk_alpha) + ntk_alpha = self.get_ntk_alpha(kv_seq_len) + ntk_alpha_list.append(ntk_alpha) self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list] @@ -312,11 +281,7 @@ class QWenModel(QWenPreTrainedModel): all_hidden_states = None for block in self.h: - hidden_states = block( - hidden_states, - rotary_pos_emb_list=rotary_pos_emb_list, - attention_mask=attention_mask, - ) + hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) @@ -332,31 +297,18 @@ class QWenLMHeadModel(QWenPreTrainedModel): self.post_init() def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs): - if input_ids.size(0) == 1: - attention_mask = None - else: - attention_mask = kwargs.get("attention_mask", None) - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "attention_mask": attention_mask, - } - ) return model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: transformer_outputs = self.transformer( input_ids, - attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, ) @@ -418,6 +370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): outputs = self.generate( input_ids, stop_words_ids=stop_words_ids, + tokenizer=tokenizer, **kwargs, ) decoded, response, end_reason = decode_tokens( @@ -434,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): self, inputs: Optional[torch.Tensor] = None, stop_words_ids=[], + tokenizer=None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: @@ -461,16 +415,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): ) # 4. Define other model kwargs - accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, - generation_config.pad_token_id, - generation_config.eos_token_id, - ) - # 5. Prepare `input_ids` which will be used for auto-regressive generation input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") @@ -551,6 +495,15 @@ class QWenLMHeadModel(QWenPreTrainedModel): next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) + # decoded, response, end_reason = decode_tokens( + # next_tokens, + # tokenizer, + # raw_text_len=0, + # context_length=0, + # errors="replace", + # ) + # print(decoded) + # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True