Remove attention_mask
parent cd50c10e8c
commit 0458e7303c
@@ -96,7 +96,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
     ):
         mixed_x_layer = self.c_attn(hidden_states)
         query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@@ -120,32 +119,21 @@ class QWenAttention(nn.Module):
             query = query * logn_tensor.expand_as(query)
 
         key_size = key.size(1)
-        if query.size(1) == key_size:
-            causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
-                1, 1, key_size, key_size
-            )
-        else:
-            causal_mask = None
+        causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
+            1, 1, key_size, key_size
+        )
         query = query.permute(0, 2, 1, 3)
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)
 
-        if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
-            if causal_mask is not None:
-                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
-        else:
-            attention_mask = causal_mask
-
         # qk = query @ key.transpose(-2, -1)
         # qk = qk[0]
-        # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
-        # print("layer:" + str(self.index) + "  query.shape:"+ str(query.shape))
-        # print("layer:" + str(self.index) + "  key.shape:"+ str(key.shape))
-        # print("layer:" + str(self.index) + "  value.shape:"+ str(value.shape))
-        # print("\n")
+        # prePath = "../generated/query_matmul_key/img/"
+        # show.DumpTensorToImage(
+        #     qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
+        # )
 
-        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
+        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(context_layer)
 
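With attention_mask gone, QWenAttention relies only on the lower-triangular boolean causal_mask built in the hunk above, which assumes query and key always have the same length (no padding, no cache). For that case the mask is equivalent to PyTorch's built-in causal option; a minimal standalone sketch, with toy shapes assumed for illustration:

    import torch
    import torch.nn.functional as F

    # Toy shapes (assumed): batch=1, heads=2, seq_len=4, head_dim=8
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)

    # Boolean lower-triangular mask, built the same way as in the hunk above.
    seq_len = q.size(-2)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)

    out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # The two paths should agree up to floating-point noise.
    assert torch.allclose(out_mask, out_causal, atol=1e-6)

When the query is shorter than the key (incremental decoding with a cache), a plain key_size x key_size tril mask would no longer line up with the attention scores, which appears to be what the deleted query.size(1) == key_size check guarded against.
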
@@ -189,15 +177,10 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
     ):
         layernorm_output = self.ln_1(hidden_states)
 
-        attn_outputs = self.attn(
-            layernorm_output,
-            rotary_pos_emb_list,
-            attention_mask=attention_mask,
-        )
+        attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list)
         attn_output = attn_outputs[0]
         residual = hidden_states
         layernorm_input = attn_output + residual
@@ -259,7 +242,6 @@ class QWenModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
@@ -275,12 +257,6 @@ class QWenModel(QWenPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        if attention_mask is not None:
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         if inputs_embeds is None:
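For reference, the lines removed here are the usual conversion of a 2D padding mask into a 4D additive attention bias; without them, padded positions are no longer suppressed. A rough standalone sketch of what the deleted code computed, with batch size and lengths assumed for illustration:

    import torch

    # Assumed example: batch of 2, second sequence padded to length 4.
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [1, 1, 1, 0]])  # 1 = real token, 0 = padding
    dtype = torch.float32

    # Reshape to (batch, 1, 1, seq_len) and turn zeros into a large negative
    # bias that would have been added to the attention logits.
    bias = attention_mask.view(2, -1)[:, None, None, :].to(dtype)
    bias = (1.0 - bias) * torch.finfo(dtype).min

    print(bias[1, 0, 0])  # last position gets ~-3.4e38, i.e. effectively masked
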
@@ -295,15 +271,8 @@ class QWenModel(QWenPreTrainedModel):
             ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
         else:
             ntk_alpha_list = []
-            if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
-                for i in range(hidden_states.size()[0]):
-                    true_seq_len = true_seq_lens[i].item()
-                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
-                    ntk_alpha_list.append(ntk_alpha)
-            else:
-                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
-                ntk_alpha_list.append(ntk_alpha)
+            ntk_alpha = self.get_ntk_alpha(kv_seq_len)
+            ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
 
@@ -312,11 +281,7 @@ class QWenModel(QWenPreTrainedModel):
 
         all_hidden_states = None
         for block in self.h:
-            hidden_states = block(
-                hidden_states,
-                rotary_pos_emb_list=rotary_pos_emb_list,
-                attention_mask=attention_mask,
-            )
+            hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -332,31 +297,18 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self.post_init()
 
     def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
-        if input_ids.size(0) == 1:
-            attention_mask = None
-        else:
-            attention_mask = kwargs.get("attention_mask", None)
-
         model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "attention_mask": attention_mask,
-            }
-        )
         return model_inputs
 
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         transformer_outputs = self.transformer(
             input_ids,
-            attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
         )
@@ -418,6 +370,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
+            tokenizer=tokenizer,
             **kwargs,
         )
         decoded, response, end_reason = decode_tokens(
@@ -434,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self,
         inputs: Optional[torch.Tensor] = None,
         stop_words_ids=[],
+        tokenizer=None,
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
@@ -461,16 +415,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         )
         # 4. Define other model kwargs
 
-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
-        requires_attention_mask = "encoder_outputs" not in model_kwargs
-
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                generation_config.pad_token_id,
-                generation_config.eos_token_id,
-            )
-
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
         input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
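Dropping the _prepare_attention_mask_for_generation call means generate() no longer infers a padding mask from pad_token_id, consistent with prepare_inputs_for_generation above now forwarding only input_ids. A rough sketch of the behaviour that is no longer derived (an approximation of the upstream helper, with a pad id assumed for illustration):

    import torch

    pad_token_id = 0  # assumed value for illustration
    input_ids = torch.tensor([[11, 12, 13],
                              [21, 22, pad_token_id]])

    # Roughly what the removed helper produced: 1 for real tokens, 0 for padding.
    attention_mask = input_ids.ne(pad_token_id).long()
    print(attention_mask)
    # tensor([[1, 1, 1],
    #         [1, 1, 0]])

Without this mask, every position is treated as a real token, so the simplified path effectively expects unpadded (typically batch-size-1) prompts.
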
@@ -551,6 +495,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
             )
 
+            # decoded, response, end_reason = decode_tokens(
+            #     next_tokens,
+            #     tokenizer,
+            #     raw_text_len=0,
+            #     context_length=0,
+            #     errors="replace",
+            # )
+            # print(decoded)
+
             # stop when each sentence is finished
             if unfinished_sequences.max() == 0:
                 this_peer_finished = True