diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py
index 653f2f1..bc36a2d 100644
--- a/chatglm/modeling_chatglm.py
+++ b/chatglm/modeling_chatglm.py
@@ -93,11 +93,11 @@ class CoreAttention(torch.nn.Module):
         self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
 
-    def forward(self, query_layer, key_layer, value_layer, attention_mask):
+    def forward(self, query_layer, key_layer, value_layer):
         query_layer, key_layer, value_layer = [
             k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
         ]
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+        if query_layer.shape[2] == key_layer.shape[2]:
             context_layer = torch.nn.functional.scaled_dot_product_attention(
                 query_layer, key_layer, value_layer, is_causal=True
             )
@@ -170,7 +170,7 @@ class SelfAttention(torch.nn.Module):
         x_out2 = x_out2.flatten(3)
         return torch.cat((x_out2, x_pass), dim=-1)
 
-    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
+    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
         # hidden_states: [sq, b, h]
         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer = self.query_key_value(hidden_states)
@@ -250,9 +250,7 @@ class SelfAttention(torch.nn.Module):
         # ==================================
         # core attention computation
         # ==================================
-        context_layer = self.core_attention(
-            query_layer, key_layer, value_layer, attention_mask
-        )
+        context_layer = self.core_attention(query_layer, key_layer, value_layer)
         # =================
         # Output. [sq, b, h]
         # =================
@@ -344,13 +342,13 @@ class GLMBlock(torch.nn.Module):
         # MLP
         self.mlp = MLP(config, device=device)
 
-    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
+    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
         # hidden_states: [s, b, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, kv_cache = self.self_attention(
-            layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache
+            layernorm_output, rotary_pos_emb, kv_cache=kv_cache
         )
         residual = hidden_states
@@ -394,9 +392,7 @@ class GLMTransformer(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        attention_mask,
         rotary_pos_emb,
-        kv_caches=None,
         use_cache: Optional[bool] = True,
     ):
         kv_caches = [None for _ in range(self.num_layers)]
@@ -405,7 +401,7 @@ class GLMTransformer(torch.nn.Module):
         for index in range(self.num_layers):
             layer = self.layers[index]
             hidden_states, kv_cache = layer(
-                hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index]
+                hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
             )
             if use_cache:
                 presents = presents + (kv_cache,)
@@ -476,13 +472,8 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.BoolTensor] = None,
-        full_attention_mask: Optional[torch.BoolTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         return_last_logit: Optional[bool] = False,
     ):
         output_hidden_states = (
@@ -491,30 +482,18 @@ class ChatGLMModel(nn.Module):
             else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        batch_size, seq_length = input_ids.shape
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embedding(input_ids)
+        inputs_embeds = self.embedding(input_ids)
 
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
-        # show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)
-        if position_ids is not None:
-            rotary_pos_emb = rotary_pos_emb[position_ids]
-        else:
-            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+        rotary_pos_emb = rotary_pos_emb[position_ids]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
         hidden_states = self.encoder(
             inputs_embeds,
-            full_attention_mask,
             rotary_pos_emb=rotary_pos_emb,
-            kv_caches=past_key_values,
             use_cache=use_cache,
         )
 
         if return_last_logit:
@@ -741,7 +720,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         model_inputs = {
             "input_ids": input_ids_in,
-            "past_key_values": None,
             "position_ids": position_ids_in,
             "return_last_logit": True,
             "use_cache": use_cache,
@@ -749,7 +727,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
         logits = self.transformer(
             **model_inputs,
-            return_dict=True,
             output_hidden_states=output_hidden_states,
         )
 
         next_token_logits = logits[:, -1, :]
diff --git a/plot.png b/plot.png
deleted file mode 100644
index fa2e598..0000000
Binary files a/plot.png and /dev/null differ
diff --git a/test.png b/test.png
deleted file mode 100644
index ba037b8..0000000
Binary files a/test.png and /dev/null differ
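
Note: dropping the attention_mask plumbing rests on scaled_dot_product_attention building the causal mask internally when is_causal=True. A minimal sketch of that equivalence follows; the tensor shapes and variable names are illustrative only and are not taken from the model, and the explicit mask here uses PyTorch's attn_mask convention (True = may attend), not necessarily the exact mask the removed argument carried.

import torch
import torch.nn.functional as F

# Illustrative sizes: batch=2, heads=4, seq_len=8, head_dim=16.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Path the refactor relies on: PyTorch constructs the lower-triangular
# (causal) mask internally when is_causal=True, so no mask tensor is needed.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit boolean causal mask (True = position may attend) for comparison.
causal_mask = torch.ones(8, 8, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# The two paths should agree up to numerical tolerance.
assert torch.allclose(out_causal, out_masked, atol=1e-5)

This equivalence only holds for the prefill case where query and key lengths match, which is why the refactored CoreAttention still checks query_layer.shape[2] == key_layer.shape[2] before taking the is_causal branch.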