diff --git a/Readme.md b/Readme.md index f1c42f4..190e9da 100644 --- a/Readme.md +++ b/Readme.md @@ -99,10 +99,10 @@ Linear(intermediate_parallel) no bias -> [6, 1, 4096] | | | | expand expand -> [6, 1, 32, 128] \ / | - dot | - softmax / - \ / - dot -> [1, 32, 6, 128] -> [6, 1, 4096] + ┏---- dot | + ┃ softmax / +attention┃ \ / + ┗---- dot -> [1, 32, 6, 128] -> [6, 1, 4096] Linear -> [6, 1, 4096] hidden_states: [s, b, h] @@ -146,7 +146,7 @@ return Linear(context_layer) -> [6, 1, 4096] Add | \ | RMSNorm - | mlp + | MLP | dropout | / Add diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py index 29e519c..e83eea2 100644 --- a/chatglm/modeling_chatglm.py +++ b/chatglm/modeling_chatglm.py @@ -316,7 +316,6 @@ class GLMBlock(torch.nn.Module): device=device, dtype=config.torch_dtype, ) - # MLP self.mlp = MLP(config, device=device) def forward(self, hidden_states, rotary_pos_emb):