Add rwkv flow graph.

This commit is contained in:
Colin 2025-03-06 23:22:50 +08:00
parent 821e7206b8
commit 251ea7f004
3 changed files with 1102 additions and 55 deletions

1026
rwkv/RWKV-v7/RWKV.drawio Normal file

File diff suppressed because it is too large Load Diff

BIN
rwkv/RWKV-v7/RWKV.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 475 KiB

View File

@ -40,7 +40,7 @@ args.vocab_size = 65536
DTYPE = torch.half # better
args.head_size_a = 64 # don't change
HEAD_SIZE = args.head_size_a
HS = args.head_size_a
########################################################################################################
# RWKV Tokenizer (slow version)
@ -144,7 +144,7 @@ class RWKV_Tmix_x070(Module):
assert args.dim_att % self.n_head == 0
H = self.n_head
N = self.head_size
HS = self.head_size
C = args.n_embd
self.x_r = nn.Parameter(torch.empty(1, 1, C))
@ -171,7 +171,7 @@ class RWKV_Tmix_x070(Module):
self.k_k = nn.Parameter(torch.empty(1, 1, C))
self.k_a = nn.Parameter(torch.empty(1, 1, C))
self.r_k = nn.Parameter(torch.empty(H, N))
self.r_k = nn.Parameter(torch.empty(H, HS))
self.receptance = nn.Linear(C, C, bias=False)
self.key = nn.Linear(C, C, bias=False)
@ -183,10 +183,9 @@ class RWKV_Tmix_x070(Module):
B, T, C = x.size() # seq_len
H = self.n_head # 12
shift = torch.zeros_like(x)
shift[:, 1:, :] = x[:, :-1, :]
xx = shift - x # time_shift [1, seq_len, 768] -> [1, seq_len, 768]
xx = torch.zeros_like(x) # time_shift [1, seq_len, 768] -> [1, seq_len, 768]
xx[:, 0, :] = -x[:, 0, :]
xx[:, 1:, :] = x[:, :-1, :] - x[:, 1:, :]
xr = x + xx * self.x_r # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xw = x + xx * self.x_w # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
@ -195,48 +194,71 @@ class RWKV_Tmix_x070(Module):
xa = x + xx * self.x_a # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xg = x + xx * self.x_g # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
r = self.receptance(xr) # [1, seq_len, 768] -> [1, seq_len, 768]
# [1, seq_len, 768] -> [1, seq_len, 768]
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # softplus clamp to (-inf, -0.5)
k = self.key(xk) # [1, seq_len, 768] -> [1, seq_len, 768]
v = self.value(xv) # [1, seq_len, 768] -> [1, seq_len, 768]
r = self.receptance(xr) # Linear [1, seq_len, 768] -> [1, seq_len, 768]
xw = torch.tanh(xw @ self.w1) # -> [1, seq_len, 64]
xw = xw @ self.w2 + self.w0 # -> [1, seq_len, 768]
xw = -F.softplus(-xw) # -softplus(-x): output range is (-inf, 0)
w = xw - 0.5 # -> [1, seq_len, 768]
k = self.key(xk) # Linear [1, seq_len, 768] -> [1, seq_len, 768]
v = self.value(xv) # Linear [1, seq_len, 768] -> [1, seq_len, 768]
if self.layer_id == 0:
v_first = v # store the v of the first layer
else:
# -> [1, seq_len, 768]
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # -> [1, seq_len, 768] # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2 # -> [1, seq_len, 768]
xv = (xv @ self.v1) @ self.v2
xv = xv + self.v0
xv = torch.sigmoid(xv)
v = v + (v_first - v) * xv # add value residual # -> [1, seq_len, 768]
xa = (xa @ self.a1) @ self.a2
xa = xa + self.a0
a = torch.sigmoid(xa) # -> [1, seq_len, 768]
xg = xg @ self.g1
xg = torch.sigmoid(xg)
g = xg @ self.g2 # -> [1, seq_len, 768]
kk = k * self.k_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) # -> [1, seq_len, 768]
k = k * (1 + (a - 1) * self.k_a) # -> [1, seq_len, 768]
def RWKV7_OP(r, w, k, v, a, b):
B, T, C = r.size() # 768
H = C // HEAD_SIZE # 12
N = HEAD_SIZE # 64
r = r.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
k = k.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
v = v.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
a = a.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
b = b.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) # -> [1, seq_len, 12, 64]
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) # -> [1, 12, 64, 64]
# start op
a_op = -kk
b_op = kk * a
for t in range(T):
kk = k[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
rr = r[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
vv = v[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
aa = a[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
bb = b[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk # -> [1, 12, 64, 64]
out[:, t, :] = (state @ rr).view(B, H, N) # -> [1, 12, 64] per step; out is [1, seq_len, 12, 64]
B, T, C = r.size() # 768
H = C // HS # 12
return out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768]
r_op = r.view(B, T, H, HS, 1).float() # -> [1, seq_len, 12, 64, 1]
k_op = k.view(B, T, H, 1, HS).float() # -> [1, seq_len, 12, 1, 64]
v_op = v.view(B, T, H, HS, 1).float() # -> [1, seq_len, 12, 64, 1]
a_op = a_op.view(B, T, H, HS, 1).float() # -> [1, seq_len, 12, 64, 1]
b_op = b_op.view(B, T, H, 1, HS).float() # -> [1, seq_len, 12, 1, 64]
x = RWKV7_OP(r, w, k, v, -kk, kk * a) # -> [1, seq_len, 768]
w_op = w.view(B, T, H, HS).float()
w_op = torch.exp(-torch.exp(w_op)) # -> [1, seq_len, 12, 64]
w_op = w_op.view(B, T, H, 1, HS) # -> [1, seq_len, 12, 1, 64]
out = torch.zeros((B, T, H, HS), device=r_op.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
state = torch.zeros((B, H, HS, HS), device=r_op.device, dtype=torch.float) # -> [1, 12, 64, 64]
for t in range(T):
rr_op = r_op[:, t, :] # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
kk_op = k_op[:, t, :] # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
vv_op = v_op[:, t, :] # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
aa_op = a_op[:, t, :] # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
bb_op = b_op[:, t, :] # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
ww_op = w_op[:, t, :] # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
state = state * ww_op + state @ aa_op @ bb_op + vv_op @ kk_op # -> [1, 12, 64, 64]
out[:, t, :] = (state @ rr_op).view(B, H, HS) # -> [1, 12, 64] per step; out is [1, seq_len, 12, 64]
x = out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768]
# end op
x = self.ln_x(x.view(B * T, C)).view(B, T, C) # -> [1, seq_len, 768]
@ -246,7 +268,7 @@ class RWKV_Tmix_x070(Module):
xx = xx * v.view(B, T, H, -1) # [1, seq_len, 12, 1] x [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
xx = xx.view(B, T, C) # -> [1, seq_len, 768]
x = x + xx # -> [1, seq_len, 768]
x = self.output(x * g) # -> [1, seq_len, 768]
x = self.output(x * g) # Linear -> [1, seq_len, 768]
return x, v_first
@ -275,8 +297,8 @@ class RWKV_CMix_x070(Module):
xx = shift - x # time_shift -> [1, seq_len, 768]
k = x + xx * self.x_k # -> [1, seq_len, 768]
k = torch.relu(self.key(k)) ** 2 # -> [1, seq_len, 768]
return self.value(k) # -> [1, seq_len, 768]
k = torch.relu(self.key(k)) ** 2 # Linear -> [1, seq_len, 768]
return self.value(k) # Linear -> [1, seq_len, 768]
########################################################################################################
@ -300,8 +322,8 @@ class Block(Module):
def forward(self, x, v_first):
if self.layer_id == 0:
x = self.ln0(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β
ln = self.ln1(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β
x = self.ln0(x) # LayerNorm -> [1, seq_len, 768] normal at dim 768 * γ + β
ln = self.ln1(x) # LayerNorm -> [1, seq_len, 768] normal at dim 768 * γ + β
xx, v_first = self.att(ln, v_first) # [1, seq_len, 768] -> [1, seq_len, 768]
x = x + xx # [1, seq_len, 768] -> [1, seq_len, 768]
@ -358,18 +380,17 @@ with torch.no_grad():
input = tokenizer.encode(prompt)
print(f"\nInput:\n{input}")
# 中国的首都是在
# 北 [probability 4.99%]
# 中 [probability 4.22%]
# 这 [probability 3.38%]
# 上 [probability 2.74%]
# 东 [probability 2.28%]
# 台 [probability 2.23%]
# 南 [probability 1.86%]
# 广 [probability 1.83%]
# 华 [probability 1.63%]
# 河 [probability 1.47%]
# 中国的首都是在
# 北 [probability 4.99%]
# 中 [probability 4.22%]
# 这 [probability 3.38%]
# 上 [probability 2.74%]
# 东 [probability 2.28%]
# 台 [probability 2.23%]
# 南 [probability 1.86%]
# 广 [probability 1.83%]
# 华 [probability 1.63%]
# 河 [probability 1.47%]
out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
print(f"\nOutput:\n{out}")