diff --git a/Readme.md b/Readme.md index 43ced71..4536cfd 100644 --- a/Readme.md +++ b/Readme.md @@ -5,9 +5,9 @@ ``` - query + query -> "你好" | - tokenizer -> input_ids + tokenizer -> [6] | rotary_pos_emb embedding -> [1, 6, 4096] \ / diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py index dc846b0..d4abdff 100644 --- a/chatglm/modeling_chatglm.py +++ b/chatglm/modeling_chatglm.py @@ -213,18 +213,9 @@ class SelfAttention(torch.nn.Module): class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - def __init__(self, config: ChatGLMConfig, device=None): super(MLP, self).__init__() - self.add_bias = config.add_bias_linear - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf self.dense_h_to_4h = nn.Linear( config.hidden_size, @@ -233,13 +224,10 @@ class MLP(torch.nn.Module): device=device, dtype=config.torch_dtype, ) - def swiglu(x): x = torch.chunk(x, 2, dim=-1) return F.silu(x[0]) * x[1] - self.activation_func = swiglu - self.dense_4h_to_h = nn.Linear( config.ffn_hidden_size, config.hidden_size, @@ -249,10 +237,8 @@ class MLP(torch.nn.Module): ) def forward(self, hidden_states): - # [s, b, 4hp] intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] output = self.dense_4h_to_h(intermediate_parallel) return output diff --git a/graph.md b/graph.md new file mode 100644 index 0000000..5896abe --- /dev/null +++ b/graph.md @@ -0,0 +1,88 @@ +## data flow + +``` + query -> "你好" + ┃ + tokenizer -> input_ids [6] + ┃ + rotary_pos_emb embedding -> [1, 6, 4096] + ╲ ╱ + GLMBlock x 28 -> [6, 1, 4096] <━━━┓ + RMSNorm -> [6, 1, 4096] ┃ final_layernorm + [-1:] -> [1, 1, 4096] ┃ + Linear -> [1, 1, 65024] ┃ output_layer 4096->65024 + softmax -> [1, 65024] ┃ + multinomial -> [1] ┃ + cat([input_ids, next_tokens]) ━━━┛ + ↓ + tokenizer.decode( ) + +# GLMBlock + + input + ╱ ╲ + ╱ RMSNorm hidden_states -> [6, 1, 4096] + ┃ ┋ ╱ ╲ + ┃ ┋ ┃ pow(2) -> [6, 1, 4096] + ┃ ┋ ┃ ┃ + ┃ ┋ ┃ mean -> [6, 1, 1] + ┃ ┋ ┃ ↓ + ┃ ┋ ┃ rsqrt( + eps) -> [6, 1, 1] + ┃ ┋ ╲ ╱ + ┃ ┋ mul -> [6, 1, 4096] + ┃ ┋ ╲ weight -> [4096] + ┃ ┋ ╲ ╱ + ┃ RMSNorm mul -> [6, 1, 4096] + ┃ ╲ + ┃ SelfAttention x -> [6, 1, 4096] + ┃ ┋ ┃ + ┃ ┋ Linear -> [6, 1, 4608] 4096->4608 + ┃ ┋ ╱ ┃ ╲ + ┃ ┋ q k v [6, 1, 32, 128] [6, 1, 2, 128] [6, 1, 2, 128] + ┃ ┋ ╱ ┃ ╲ + ┃ ┋ pos_emb pos_emb ╲ -> cat( x0*y0-x1*y1, x1*y0-x0*y1, x, y) + ┃ ┋ ┃ ┃ ┃ + ┃ ┋ ┃ expand expand -> [6, 1, 32, 128] [6, 1, 32, 128] + ┃ ┋ permute permute permute -> [1, 32, 6, 128] [1, 32, 6, 128] [1, 32, 6, 128] + ┃ ┋ ╲ ╱ ┃ + ┃ ┋ ┏---- matmul ┃ -> [1, 32, 6, 128] [1, 32, 128, 6] -> [1, 32, 6, 6] + ┃ ┋ ┃ add(mask) ╱ -> [1, 32, 6, 6] + ┃ ┋ attention┃ softmax ╱ -> [1, 32, 6, 6] dim:-1 + ┃ ┋ ┃ ╲ ╱ + ┃ ┋ ┗---- matmul -> [1, 32, 6, 6] [1, 32, 6, 128] -> [1, 32, 6, 128] -> [6, 1, 4096] + ┃ SelfAttention Linear -> [6, 1, 4096] 4096->4096 + ┃ ╱ + ┃ dropout + ╲ ╱ + Add + ╱ ╲ + ┃ RMSNorm hidden_states -> [6, 1, 4096] + ┃ ┋ ╱ ╲ + ┃ ┋ ┃ pow(2) -> [6, 1, 4096] + ┃ ┋ ┃ ┃ + ┃ ┋ ┃ mean -> [6, 1, 1] + ┃ ┋ ┃ ↓ + ┃ ┋ ┃ rsqrt( + eps) -> [6, 1, 1] + ┃ ┋ ╲ ╱ + ┃ ┋ mul -> [6, 1, 4096] + ┃ ┋ ╲ weight -> [4096] + ┃ ┋ ╲ ╱ + ┃ RMSNorm mul -> [6, 1, 4096] + ┃ ╱ + ┃ mlp ╱ + ┃ ┋ Linear -> [6, 1, 27392] 4096->27392 + ┃ ┋ ╱ ╲ + ┃ ┋ chunk1 chunk0 -> [6, 1, 13696] + ┃ ┋ ┃ ┃ ╲ + ┃ ┋ ┃ ┃ sigmoid + ┃ ┋ ┃ ┃ ╱ + ┃ ┋ ┃ mul + ┃ ┋ ╲ ╱ + ┃ ┋ mul -> [6, 1, 13696] + ┃ mlp Linear -> [6, 1, 4096] 13696->4096 + ┃ ╱ + ┃ dropout + ┃ ╱ + Add + +``` diff --git a/tensor.py b/tensor.py index 24633d4..eef0b44 100644 --- a/tensor.py +++ b/tensor.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F x = torch.tensor([[1, 2], [3, 4]]) @@ -47,3 +48,11 @@ print(torch.cat((x, x), 1)) # So if A and B are of shape (3, 4): # torch.cat([A, B], dim=0) will be of shape (6, 4) # torch.stack([A, B], dim=0) will be of shape (2, 3, 4) + +x = torch.ones([1, 32, 6, 128]) +y = torch.ones([1, 32, 128, 6]) +z = torch.ones([1, 32, 6, 128]) +att = torch.matmul(x, y) +mm = torch.matmul(att, z) +print(mm.shape) +