This commit is contained in:
Colin 2023-12-29 20:39:40 +08:00
parent 3db6e2c25b
commit dff2b9231f
2 changed files with 5 additions and 6 deletions

View File

@ -99,10 +99,10 @@ Linear(intermediate_parallel) no bias -> [6, 1, 4096]
| | | | | |
| expand expand -> [6, 1, 32, 128] | expand expand -> [6, 1, 32, 128]
\ / | \ / |
dot | ┏---- dot |
softmax / softmax /
\ / attention┃ \ /
dot -> [1, 32, 6, 128] -> [6, 1, 4096] ┗---- dot -> [1, 32, 6, 128] -> [6, 1, 4096]
Linear -> [6, 1, 4096] Linear -> [6, 1, 4096]
hidden_states: [s, b, h] hidden_states: [s, b, h]
@ -146,7 +146,7 @@ return Linear(context_layer) -> [6, 1, 4096]
Add Add
| \ | \
| RMSNorm | RMSNorm
| mlp | MLP
| dropout | dropout
| / | /
Add Add

View File

@ -316,7 +316,6 @@ class GLMBlock(torch.nn.Module):
device=device, device=device,
dtype=config.torch_dtype, dtype=config.torch_dtype,
) )
# MLP
self.mlp = MLP(config, device=device) self.mlp = MLP(config, device=device)
def forward(self, hidden_states, rotary_pos_emb): def forward(self, hidden_states, rotary_pos_emb):