Update.

commit dff2b9231f
parent 3db6e2c25b

Readme.md (+10 −10)
@@ -99,10 +99,10 @@ Linear(intermediate_parallel) no bias -> [6, 1, 4096]
     |        |        |
     expand   expand   -> [6, 1, 32, 128]
      \      /         |
-      dot             |
-      softmax        /
-       \            /
-        dot -> [1, 32, 6, 128] -> [6, 1, 4096]
+┏---- dot             |
+┃     softmax        /
+attention┃  \       /
+┗---- dot -> [1, 32, 6, 128] -> [6, 1, 4096]
 Linear -> [6, 1, 4096]
 |
 hidden_states: [s, b, h]
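The newly bracketed region is the attention core: fused QKV projection, head expansion, QKᵀ, softmax, and the weighted sum over V. Below is a minimal PyTorch sketch of those shapes. Note the diagram's two `expand` arrows suggest multi-query K/V expansion; this sketch simplifies to plain multi-head attention, and all layer names here are assumptions, not the model's actual API.

```python
import torch
import torch.nn.functional as F

s, b, h = 6, 1, 4096           # [seq, batch, hidden], as in the diagram
n_head, d_head = 32, 128       # 32 * 128 == 4096

hidden_states = torch.randn(s, b, h)

# Fused QKV projection with no bias ("Linear(intermediate_parallel) no bias").
qkv_proj = torch.nn.Linear(h, 3 * h, bias=False)
q, k, v = qkv_proj(hidden_states).chunk(3, dim=-1)

# expand -> [6, 1, 32, 128], then put heads before sequence: [b, n_head, s, d_head]
def to_heads(t):
    return t.view(s, b, n_head, d_head).permute(1, 2, 0, 3)

q, k, v = to_heads(q), to_heads(k), to_heads(v)

# dot -> softmax -> dot: scaled dot-product attention
scores = q @ k.transpose(-1, -2) / d_head ** 0.5   # [1, 32, 6, 6]
context = F.softmax(scores, dim=-1) @ v            # [1, 32, 6, 128]

# [1, 32, 6, 128] -> [6, 1, 4096], then the output Linear
context = context.permute(2, 0, 1, 3).reshape(s, b, h)
out_proj = torch.nn.Linear(h, h, bias=False)       # bias-free by assumption, like QKV
out = out_proj(context)                            # [6, 1, 4096]
```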

@@ -146,7 +146,7 @@ return Linear(context_layer) -> [6, 1, 4096]
 Add
 | \
 |  RMSNorm
-|  mlp
+|  MLP
 |  dropout
 | /
 Add
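This half of the block is a single residual branch: the stream splits after the first Add, runs RMSNorm -> MLP -> dropout, and rejoins at the second Add. A minimal sketch of that wiring, with module names assumed (`torch.nn.RMSNorm` needs PyTorch >= 2.4; ChatGLM ships its own RMSNorm):

```python
import torch

class PostAttentionResidual(torch.nn.Module):
    """Sketch of the lower half of the diagram: split after the first Add,
    run RMSNorm -> MLP -> dropout, rejoin at the second Add."""
    def __init__(self, hidden_size=4096, dropout_p=0.0):
        super().__init__()
        self.norm = torch.nn.RMSNorm(hidden_size)  # PyTorch >= 2.4
        self.mlp = torch.nn.Identity()             # stand-in for the model's MLP
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        # x arrives from the first Add; the second Add is the `x +` below.
        return x + self.dropout(self.mlp(self.norm(x)))
```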
@@ -316,7 +316,6 @@ class GLMBlock(torch.nn.Module):
             device=device,
             dtype=config.torch_dtype,
         )
-        # MLP
         self.mlp = MLP(config, device=device)
 
     def forward(self, hidden_states, rotary_pos_emb):
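`MLP(config, device=device)` builds the feed-forward sublayer next to the attention weights. For orientation, a hedged sketch of what such an MLP typically looks like in ChatGLM-style models: a fused gated (SwiGLU) projection. The names `dense_h_to_4h`/`dense_4h_to_h` and the 13696 intermediate width are assumptions here, not read from this diff.

```python
import torch
import torch.nn.functional as F

class MLPSketch(torch.nn.Module):
    """Hypothetical ChatGLM-style gated (SwiGLU) MLP; names and widths
    are assumptions, not the repository's actual MLP class."""
    def __init__(self, hidden_size=4096, ffn_hidden_size=13696,
                 device=None, dtype=None):
        super().__init__()
        # One fused projection producing both the gate and the up branch.
        self.dense_h_to_4h = torch.nn.Linear(
            hidden_size, 2 * ffn_hidden_size, bias=False,
            device=device, dtype=dtype)
        self.dense_4h_to_h = torch.nn.Linear(
            ffn_hidden_size, hidden_size, bias=False,
            device=device, dtype=dtype)

    def forward(self, hidden_states):
        gate, up = self.dense_h_to_4h(hidden_states).chunk(2, dim=-1)
        return self.dense_4h_to_h(F.silu(gate) * up)

# Usage: mlp = MLPSketch(); y = mlp(torch.randn(6, 1, 4096))  # -> [6, 1, 4096]
```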