Refine model code.

Colin 2023-12-22 11:39:06 +08:00
parent 539392c843
commit 84938e565e
3 changed files with 9 additions and 32 deletions


@@ -93,11 +93,11 @@ class CoreAttention(torch.nn.Module):
         self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

-    def forward(self, query_layer, key_layer, value_layer, attention_mask):
+    def forward(self, query_layer, key_layer, value_layer):
         query_layer, key_layer, value_layer = [
             k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
         ]

-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+        if query_layer.shape[2] == key_layer.shape[2]:
             context_layer = torch.nn.functional.scaled_dot_product_attention(
                 query_layer, key_layer, value_layer, is_causal=True
             )
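Note on the hunk above: with is_causal=True, scaled_dot_product_attention builds the lower-triangular causal mask internally, which is why the explicit attention_mask argument can be dropped for plain causal self-attention. A minimal standalone sketch of that equivalence (illustration only, not code from this repository; shapes assumed to be [batch, heads, seq, head_dim], matching the layout after the permute(1, 2, 0, 3) above):

import torch
import torch.nn.functional as F

# Illustrative shapes, made up for the example.
b, heads, seq, dim = 1, 2, 4, 8
q, k, v = (torch.randn(b, heads, seq, dim) for _ in range(3))

# New path: the causal mask is generated inside the kernel.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Old-style path: the same result with an explicit boolean mask (True = attend).
causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

assert torch.allclose(out_causal, out_masked, atol=1e-6)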
@@ -170,7 +170,7 @@ class SelfAttention(torch.nn.Module):
         x_out2 = x_out2.flatten(3)
         return torch.cat((x_out2, x_pass), dim=-1)

-    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
+    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
         # hidden_states: [sq, b, h]
         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer = self.query_key_value(hidden_states)
@@ -250,9 +250,7 @@ class SelfAttention(torch.nn.Module):
         # ==================================
         # core attention computation
         # ==================================
-        context_layer = self.core_attention(
-            query_layer, key_layer, value_layer, attention_mask
-        )
+        context_layer = self.core_attention(query_layer, key_layer, value_layer)
         # =================
         # Output. [sq, b, h]
         # =================
@@ -344,13 +342,13 @@ class GLMBlock(torch.nn.Module):
         # MLP
         self.mlp = MLP(config, device=device)

-    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
+    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
         # hidden_states: [s, b, h]
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, kv_cache = self.self_attention(
-            layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache
+            layernorm_output, rotary_pos_emb, kv_cache=kv_cache
         )

         residual = hidden_states
@@ -394,9 +392,7 @@ class GLMTransformer(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        attention_mask,
         rotary_pos_emb,
-        kv_caches=None,
         use_cache: Optional[bool] = True,
     ):
+        kv_caches = [None for _ in range(self.num_layers)]
@@ -405,7 +401,7 @@ class GLMTransformer(torch.nn.Module):
         for index in range(self.num_layers):
             layer = self.layers[index]
             hidden_states, kv_cache = layer(
-                hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index]
+                hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
             )
             if use_cache:
                 presents = presents + (kv_cache,)
@ -476,13 +472,8 @@ class ChatGLMModel(nn.Module):
self, self,
input_ids, input_ids,
position_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False, return_last_logit: Optional[bool] = False,
): ):
output_hidden_states = ( output_hidden_states = (
@@ -491,30 +482,18 @@ class ChatGLMModel(nn.Module):
             else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )

         batch_size, seq_length = input_ids.shape
-        if inputs_embeds is None:
-            inputs_embeds = self.embedding(input_ids)
+        inputs_embeds = self.embedding(input_ids)
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
         # show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)
-        if position_ids is not None:
-            rotary_pos_emb = rotary_pos_emb[position_ids]
-        else:
-            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+        rotary_pos_emb = rotary_pos_emb[position_ids]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

         hidden_states = self.encoder(
             inputs_embeds,
-            full_attention_mask,
             rotary_pos_emb=rotary_pos_emb,
-            kv_caches=past_key_values,
             use_cache=use_cache,
         )
         if return_last_logit:
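With the `if position_ids is not None` fallback removed, `rotary_pos_emb[position_ids]` now assumes callers always supply position_ids. A caller-side sketch of the tensor for a simple left-to-right prompt (shapes made up for illustration; not code from this repository):

import torch

# Hypothetical caller-side construction; batch_size and seq_length are example values.
batch_size, seq_length = 2, 16
position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
# Shape [batch_size, seq_length]; indexing rotary_pos_emb with this matches the removed
# rotary_pos_emb[None, :seq_length] default (up to the batch dimension) when positions
# simply count up from 0.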
@@ -741,7 +720,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         model_inputs = {
             "input_ids": input_ids_in,
-            "past_key_values": None,
             "position_ids": position_ids_in,
             "return_last_logit": True,
             "use_cache": use_cache,
@@ -749,7 +727,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
         logits = self.transformer(
             **model_inputs,
-            return_dict=True,
             output_hidden_states=output_hidden_states,
         )

         next_token_logits = logits[:, -1, :]
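With past_key_values gone from model_inputs and kv_caches fixed to None inside GLMTransformer, each decoding step re-encodes the full prefix and only the last position's logits are used. A hypothetical greedy loop under that assumption (the call shape mirrors the model_inputs dict above; greedy_generate and the exact return value of model(...) are assumptions for illustration, not repository code):

import torch

def greedy_generate(model, input_ids, max_new_tokens=8):
    # Hypothetical sketch: no KV cache is threaded between steps, so the whole
    # sequence is re-run each iteration and only the last position is kept.
    for _ in range(max_new_tokens):
        batch_size, seq_length = input_ids.shape
        position_ids = (
            torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        logits = model(
            input_ids=input_ids,
            position_ids=position_ids,
            return_last_logit=True,
        )
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids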

BIN  plot.png (binary file not shown; size before: 262 KiB)

BIN  test.png (binary file not shown; size before: 217 B)