Refien model code.
This commit is contained in:
parent
539392c843
commit
84938e565e
|
@ -93,11 +93,11 @@ class CoreAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
||||||
|
|
||||||
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
def forward(self, query_layer, key_layer, value_layer):
|
||||||
query_layer, key_layer, value_layer = [
|
query_layer, key_layer, value_layer = [
|
||||||
k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
|
k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
|
||||||
]
|
]
|
||||||
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
if query_layer.shape[2] == key_layer.shape[2]:
|
||||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||||
query_layer, key_layer, value_layer, is_causal=True
|
query_layer, key_layer, value_layer, is_causal=True
|
||||||
)
|
)
|
||||||
|
@ -170,7 +170,7 @@ class SelfAttention(torch.nn.Module):
|
||||||
x_out2 = x_out2.flatten(3)
|
x_out2 = x_out2.flatten(3)
|
||||||
return torch.cat((x_out2, x_pass), dim=-1)
|
return torch.cat((x_out2, x_pass), dim=-1)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
|
def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
|
||||||
# hidden_states: [sq, b, h]
|
# hidden_states: [sq, b, h]
|
||||||
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
||||||
mixed_x_layer = self.query_key_value(hidden_states)
|
mixed_x_layer = self.query_key_value(hidden_states)
|
||||||
|
@ -250,9 +250,7 @@ class SelfAttention(torch.nn.Module):
|
||||||
# ==================================
|
# ==================================
|
||||||
# core attention computation
|
# core attention computation
|
||||||
# ==================================
|
# ==================================
|
||||||
context_layer = self.core_attention(
|
context_layer = self.core_attention(query_layer, key_layer, value_layer)
|
||||||
query_layer, key_layer, value_layer, attention_mask
|
|
||||||
)
|
|
||||||
# =================
|
# =================
|
||||||
# Output. [sq, b, h]
|
# Output. [sq, b, h]
|
||||||
# =================
|
# =================
|
||||||
|
@ -344,13 +342,13 @@ class GLMBlock(torch.nn.Module):
|
||||||
# MLP
|
# MLP
|
||||||
self.mlp = MLP(config, device=device)
|
self.mlp = MLP(config, device=device)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
|
def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
|
||||||
# hidden_states: [s, b, h]
|
# hidden_states: [s, b, h]
|
||||||
# Layer norm at the beginning of the transformer layer.
|
# Layer norm at the beginning of the transformer layer.
|
||||||
layernorm_output = self.input_layernorm(hidden_states)
|
layernorm_output = self.input_layernorm(hidden_states)
|
||||||
# Self attention.
|
# Self attention.
|
||||||
attention_output, kv_cache = self.self_attention(
|
attention_output, kv_cache = self.self_attention(
|
||||||
layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache
|
layernorm_output, rotary_pos_emb, kv_cache=kv_cache
|
||||||
)
|
)
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
@ -394,9 +392,7 @@ class GLMTransformer(torch.nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
|
||||||
rotary_pos_emb,
|
rotary_pos_emb,
|
||||||
kv_caches=None,
|
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
kv_caches = [None for _ in range(self.num_layers)]
|
kv_caches = [None for _ in range(self.num_layers)]
|
||||||
|
@ -405,7 +401,7 @@ class GLMTransformer(torch.nn.Module):
|
||||||
for index in range(self.num_layers):
|
for index in range(self.num_layers):
|
||||||
layer = self.layers[index]
|
layer = self.layers[index]
|
||||||
hidden_states, kv_cache = layer(
|
hidden_states, kv_cache = layer(
|
||||||
hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index]
|
hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
|
||||||
)
|
)
|
||||||
if use_cache:
|
if use_cache:
|
||||||
presents = presents + (kv_cache,)
|
presents = presents + (kv_cache,)
|
||||||
|
@ -476,13 +472,8 @@ class ChatGLMModel(nn.Module):
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
attention_mask: Optional[torch.BoolTensor] = None,
|
|
||||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
return_last_logit: Optional[bool] = False,
|
return_last_logit: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
|
@ -491,30 +482,18 @@ class ChatGLMModel(nn.Module):
|
||||||
else self.config.output_hidden_states
|
else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.embedding(input_ids)
|
inputs_embeds = self.embedding(input_ids)
|
||||||
|
|
||||||
# Rotary positional embeddings
|
# Rotary positional embeddings
|
||||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||||
|
|
||||||
# show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)
|
# show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)
|
||||||
|
|
||||||
if position_ids is not None:
|
|
||||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||||
else:
|
|
||||||
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
|
||||||
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
||||||
hidden_states = self.encoder(
|
hidden_states = self.encoder(
|
||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
full_attention_mask,
|
|
||||||
rotary_pos_emb=rotary_pos_emb,
|
rotary_pos_emb=rotary_pos_emb,
|
||||||
kv_caches=past_key_values,
|
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
if return_last_logit:
|
if return_last_logit:
|
||||||
|
@ -741,7 +720,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
model_inputs = {
|
model_inputs = {
|
||||||
"input_ids": input_ids_in,
|
"input_ids": input_ids_in,
|
||||||
"past_key_values": None,
|
|
||||||
"position_ids": position_ids_in,
|
"position_ids": position_ids_in,
|
||||||
"return_last_logit": True,
|
"return_last_logit": True,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
|
@ -749,7 +727,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
|
|
||||||
logits = self.transformer(
|
logits = self.transformer(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
return_dict=True,
|
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
)
|
)
|
||||||
next_token_logits = logits[:, -1, :]
|
next_token_logits = logits[:, -1, :]
|
||||||
|
|
Loading…
Reference in New Issue