Update graph.md.

parent b3ef30aa1a
commit 0fa38b7815
@@ -5,9 +5,9 @@
 ```

-query
+query -> "你好"

-tokenizer -> input_ids
+tokenizer -> [6]

 rotary_pos_emb  embedding -> [1, 6, 4096]
       \          /
@@ -213,18 +213,9 @@ class SelfAttention(torch.nn.Module):


 class MLP(torch.nn.Module):
-    """MLP.
-
-    MLP will take the input with h hidden state, project it to 4*h
-    hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension.
-    """
-
     def __init__(self, config: ChatGLMConfig, device=None):
         super(MLP, self).__init__()

         self.add_bias = config.add_bias_linear

         # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
         self.dense_h_to_4h = nn.Linear(
             config.hidden_size,
@@ -233,13 +224,10 @@ class MLP(torch.nn.Module):
             device=device,
             dtype=config.torch_dtype,
         )

         def swiglu(x):
             x = torch.chunk(x, 2, dim=-1)
             return F.silu(x[0]) * x[1]

         self.activation_func = swiglu

         self.dense_4h_to_h = nn.Linear(
             config.ffn_hidden_size,
             config.hidden_size,
@@ -249,10 +237,8 @@
         )

     def forward(self, hidden_states):
-        # [s, b, 4hp]
         intermediate_parallel = self.dense_h_to_4h(hidden_states)
         intermediate_parallel = self.activation_func(intermediate_parallel)
-        # [s, b, h]
         output = self.dense_4h_to_h(intermediate_parallel)
         return output
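For reference, the up-project / SwiGLU / down-project path of this MLP can be exercised standalone. This is a minimal sketch, not the model's code: the 4096 and 13696 widths are assumed to match the shapes in graph.md, bias is dropped, and the ChatGLMConfig plumbing is left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Standalone sketch of the MLP above; sizes are assumptions matching graph.md.
hidden_size, ffn_hidden_size = 4096, 13696

dense_h_to_4h = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)
dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

def swiglu(x):
    # Split the doubled projection into a gate half and a value half.
    x = torch.chunk(x, 2, dim=-1)
    return F.silu(x[0]) * x[1]

hidden_states = torch.randn(6, 1, hidden_size)    # [s, b, h]
intermediate = dense_h_to_4h(hidden_states)       # [6, 1, 27392]
intermediate = swiglu(intermediate)               # [6, 1, 13696]
output = dense_4h_to_h(intermediate)              # [6, 1, 4096]
print(output.shape)
```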
@@ -0,0 +1,88 @@
## data flow

```
query                      -> "你好"
  ┃
tokenizer                  -> input_ids [6]
  ┃
rotary_pos_emb   embedding -> [1, 6, 4096]
      ╲         ╱
   GLMBlock x 28                  -> [6, 1, 4096]   <━━━┓
   RMSNorm                        -> [6, 1, 4096]       ┃  final_layernorm
   [-1:]                          -> [1, 1, 4096]       ┃
   Linear                         -> [1, 1, 65024]      ┃  output_layer 4096->65024
   softmax                        -> [1, 65024]         ┃
   multinomial                    -> [1]                ┃
   cat([input_ids, next_tokens])                     ━━━┛
      ↓
   tokenizer.decode( )


# GLMBlock

     input
     ╱   ╲
    ╱     RMSNorm  hidden_states  -> [6, 1, 4096]
   ┃   ┋     ╱    ╲
   ┃   ┋    ┃    pow(2)           -> [6, 1, 4096]
   ┃   ┋    ┃      ┃
   ┃   ┋    ┃    mean             -> [6, 1, 1]
   ┃   ┋    ┃      ↓
   ┃   ┋    ┃    rsqrt( + eps)    -> [6, 1, 1]
   ┃   ┋     ╲    ╱
   ┃   ┋      mul                 -> [6, 1, 4096]
   ┃   ┋        ╲   weight        -> [4096]
   ┃   ┋         ╲ ╱
   ┃     RMSNorm  mul             -> [6, 1, 4096]
   ┃               ╲
   ┃     SelfAttention  x         -> [6, 1, 4096]
   ┃   ┋                ┃
   ┃   ┋              Linear      -> [6, 1, 4608]   4096->4608
   ┃   ┋              ╱  ┃  ╲
   ┃   ┋             q   k   v       [6, 1, 32, 128] [6, 1, 2, 128] [6, 1, 2, 128]
   ┃   ┋            ╱    ┃    ╲
   ┃   ┋      pos_emb  pos_emb  ╲    -> cat( x0*y0-x1*y1, x1*y0+x0*y1, x, y)
   ┃   ┋         ┃        ┃      ┃
   ┃   ┋         ┃     expand   expand   -> [6, 1, 32, 128] [6, 1, 32, 128]
   ┃   ┋      permute  permute  permute  -> [1, 32, 6, 128] [1, 32, 6, 128] [1, 32, 6, 128]
   ┃   ┋          ╲      ╱       ┃
   ┃   ┋     ┏---- matmul        ┃   -> [1, 32, 6, 128] [1, 32, 128, 6] -> [1, 32, 6, 6]
   ┃   ┋     ┃    add(mask)     ╱    -> [1, 32, 6, 6]
   ┃   ┋ attention┃  softmax   ╱     -> [1, 32, 6, 6]  dim:-1
   ┃   ┋     ┃        ╲       ╱
   ┃   ┋     ┗---- matmul            -> [1, 32, 6, 6] [1, 32, 6, 128] -> [1, 32, 6, 128] -> [6, 1, 4096]
   ┃     SelfAttention  Linear       -> [6, 1, 4096]   4096->4096
   ┃                   ╱
   ┃              dropout
    ╲             ╱
        Add
       ╱   ╲
   ┃     RMSNorm  hidden_states   -> [6, 1, 4096]
   ┃   ┋     ╱    ╲
   ┃   ┋    ┃    pow(2)           -> [6, 1, 4096]
   ┃   ┋    ┃      ┃
   ┃   ┋    ┃    mean             -> [6, 1, 1]
   ┃   ┋    ┃      ↓
   ┃   ┋    ┃    rsqrt( + eps)    -> [6, 1, 1]
   ┃   ┋     ╲    ╱
   ┃   ┋      mul                 -> [6, 1, 4096]
   ┃   ┋        ╲   weight        -> [4096]
   ┃   ┋         ╲ ╱
   ┃     RMSNorm  mul             -> [6, 1, 4096]
   ┃             ╱
   ┃     mlp    ╱
   ┃   ┋      Linear              -> [6, 1, 27392]  4096->27392
   ┃   ┋      ╱    ╲
   ┃   ┋  chunk1    chunk0        -> [6, 1, 13696]
   ┃   ┋    ┃        ┃   ╲
   ┃   ┋    ┃        ┃  sigmoid
   ┃   ┋    ┃        ┃   ╱
   ┃   ┋    ┃       mul
   ┃   ┋     ╲      ╱
   ┃   ┋       mul                -> [6, 1, 13696]
   ┃     mlp  Linear              -> [6, 1, 4096]   13696->4096
   ┃         ╱
   ┃    dropout
   ┃   ╱
     Add

```
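The RMSNorm sub-graph above (pow(2) → mean → rsqrt( + eps) → scale by weight) can be replayed in a few lines. A minimal sketch: the eps value is an assumption, and the all-ones weight stands in for the real layer's learned parameter.

```python
import torch

# RMSNorm step by step, following the sub-graph above; eps is an assumption.
eps = 1e-5
hidden_states = torch.randn(6, 1, 4096)                    # [s, b, h]
weight = torch.ones(4096)                                  # stand-in for the learned scale

variance = hidden_states.pow(2).mean(-1, keepdim=True)     # [6, 1, 1]
normed = hidden_states * torch.rsqrt(variance + eps)       # [6, 1, 4096]
output = normed * weight                                   # [6, 1, 4096]
print(output.shape)
```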
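Likewise, the SelfAttention shapes in the graph (the 4608-wide projection split into q/k/v, the 2 key/value heads expanded to 32 query heads, then the two matmuls) can be reproduced with dummy tensors. A sketch only: rotary embedding and the causal mask are omitted, and the real layer's exact scaling may differ.

```python
import torch
import torch.nn.functional as F

# Attention core from the GLMBlock graph; sizes follow the diagram.
s, b, n_head, n_kv, d = 6, 1, 32, 2, 128

mixed = torch.randn(s, b, n_head * d + 2 * n_kv * d)       # [6, 1, 4608]
q, k, v = mixed.split([n_head * d, n_kv * d, n_kv * d], dim=-1)
q = q.view(s, b, n_head, d)                                # [6, 1, 32, 128]
k = k.view(s, b, n_kv, d)                                  # [6, 1, 2, 128]
v = v.view(s, b, n_kv, d)                                  # [6, 1, 2, 128]

# Grouped/multi-query attention: each kv head is shared by 16 query heads.
k = k.unsqueeze(3).expand(s, b, n_kv, n_head // n_kv, d).reshape(s, b, n_head, d)
v = v.unsqueeze(3).expand(s, b, n_kv, n_head // n_kv, d).reshape(s, b, n_head, d)

q, k, v = (t.permute(1, 2, 0, 3) for t in (q, k, v))       # [1, 32, 6, 128] each
scores = torch.matmul(q, k.transpose(-1, -2)) / d ** 0.5   # [1, 32, 6, 6]; mask omitted
probs = F.softmax(scores, dim=-1)                          # [1, 32, 6, 6]
context = torch.matmul(probs, v)                           # [1, 32, 6, 128]
context = context.permute(2, 0, 1, 3).reshape(s, b, n_head * d)  # [6, 1, 4096]
print(context.shape)
```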
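Finally, one sampling step from the data-flow graph (last position → output_layer → softmax → multinomial → cat back onto input_ids), as a sketch with a random projection and random token ids standing in for the real weights and tokenizer output.

```python
import torch

# One decoding step following the data-flow graph; weights and ids are random stand-ins.
vocab_size, hidden_size = 65024, 4096
output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=False)

input_ids = torch.randint(0, vocab_size, (1, 6))      # stand-in for tokenizer("你好"), [1, 6]
hidden = torch.randn(6, 1, hidden_size)               # GLMBlock x 28 output, [s, b, h]
last = hidden[-1:]                                    # [1, 1, 4096]
logits = output_layer(last)                           # [1, 1, 65024]
probs = torch.softmax(logits[:, -1], dim=-1)          # [1, 65024]
next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
input_ids = torch.cat([input_ids, next_token], dim=-1)
print(input_ids.shape)                                # [1, 7]
```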
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F

 x = torch.tensor([[1, 2], [3, 4]])
@@ -47,3 +48,11 @@ print(torch.cat((x, x), 1))
 # So if A and B are of shape (3, 4):
 # torch.cat([A, B], dim=0) will be of shape (6, 4)
 # torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
+
+x = torch.ones([1, 32, 6, 128])
+y = torch.ones([1, 32, 128, 6])
+z = torch.ones([1, 32, 6, 128])
+att = torch.matmul(x, y)   # [1, 32, 6, 6]
+mm = torch.matmul(att, z)  # [1, 32, 6, 128]
+print(mm.shape)