Refine rwkv.

This commit is contained in:
Colin 2025-03-05 19:39:08 +08:00
parent 240858c030
commit 821e7206b8
4 changed files with 227 additions and 59 deletions

View File

@ -14,5 +14,4 @@ K、V 就是等同于Transformer的Key与Value。
TimeMix指的是过去信息x-1与当前信息x的混合。 xx = self.time_shift(x) - x 这个是典型的操作
nn.Embedding
RWKV_Tmix_x070

View File

@ -173,7 +173,6 @@ class RWKV_Tmix_x070(Module):
self.k_a = nn.Parameter(torch.empty(1, 1, C))
self.r_k = nn.Parameter(torch.empty(H, N))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(C, C, bias=False)
self.key = nn.Linear(C, C, bias=False)
self.value = nn.Linear(C, C, bias=False)
@ -181,64 +180,73 @@ class RWKV_Tmix_x070(Module):
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
def forward(self, x, v_first):
B, T, C = x.size()
H = self.n_head
xx = self.time_shift(x) - x
B, T, C = x.size() # seq_len
H = self.n_head # 12
xr = x + xx * self.x_r
xw = x + xx * self.x_w
xk = x + xx * self.x_k
xv = x + xx * self.x_v
xa = x + xx * self.x_a
xg = x + xx * self.x_g
shift = torch.zeros_like(x)
shift[:, 1:, :] = x[:, :-1, :]
r = self.receptance(xr)
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
k = self.key(xk)
v = self.value(xv)
xx = shift - x # time_shift [1, seq_len, 768] -> [1, seq_len, 768]
xr = x + xx * self.x_r # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xw = x + xx * self.x_w # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xk = x + xx * self.x_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xv = x + xx * self.x_v # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xa = x + xx * self.x_a # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
xg = x + xx * self.x_g # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
r = self.receptance(xr) # [1, seq_len, 768] -> [1, seq_len, 768]
# [1, seq_len, 768] -> [1, seq_len, 768]
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # softplus clamp to (-inf, -0.5)
k = self.key(xk) # [1, seq_len, 768] -> [1, seq_len, 768]
v = self.value(xv) # [1, seq_len, 768] -> [1, seq_len, 768]
if self.layer_id == 0:
v_first = v # store the v of the first layer
else:
# -> [1, seq_len, 768]
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # -> [1, seq_len, 768] # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2 # -> [1, seq_len, 768]
kk = k * self.k_k
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
k = k * (1 + (a - 1) * self.k_a)
kk = k * self.k_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) # -> [1, seq_len, 768]
k = k * (1 + (a - 1) * self.k_a) # -> [1, seq_len, 768]
def RWKV7_OP(r, w, k, v, a, b):
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
r = r.view(B, T, H, N).float()
k = k.view(B, T, H, N).float()
v = v.view(B, T, H, N).float()
a = a.view(B, T, H, N).float()
b = b.view(B, T, H, N).float()
w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
B, T, C = r.size() # 768
H = C // HEAD_SIZE # 12
N = HEAD_SIZE # 64
r = r.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
k = k.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
v = v.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
a = a.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
b = b.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) # -> [1, seq_len, 12, 64]
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
for t in range(T):
kk = k[:, t, :].view(B, H, 1, N)
rr = r[:, t, :].view(B, H, N, 1)
vv = v[:, t, :].view(B, H, N, 1)
aa = a[:, t, :].view(B, H, N, 1)
bb = b[:, t, :].view(B, H, 1, N)
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
out[:, t, :] = (state @ rr).view(B, H, N)
kk = k[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
rr = r[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
vv = v[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
aa = a[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
bb = b[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk # -> [1, 12, 64, 64]
out[:, t, :] = (state @ rr).view(B, H, N) # -> [1, seq_len, 12, 64]
return out.view(B, T, C).to(dtype=DTYPE)
return out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768]
x = RWKV7_OP(r, w, k, v, -kk, kk * a)
x = RWKV7_OP(r, w, k, v, -kk, kk * a) # -> [1, seq_len, 768]
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = self.ln_x(x.view(B * T, C)).view(B, T, C) # -> [1, seq_len, 768]
x = x + (
(r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
).view(B, T, C)
x = self.output(x * g)
xx = r.view(B, T, H, -1) * k.view(B, T, H, -1) # -> [1, seq_len, 12, 64]
xx = xx * self.r_k # -> [1, seq_len, 12, 64]
xx = xx.sum(dim=-1, keepdim=True) # -> [1, seq_len, 12, 1]
xx = xx * v.view(B, T, H, -1) # [1, seq_len, 12, 1] x [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
xx = xx.view(B, T, C) # -> [1, seq_len, 768]
x = x + xx # -> [1, seq_len, 768]
x = self.output(x * g) # -> [1, seq_len, 768]
return x, v_first
@ -252,7 +260,6 @@ class RWKV_CMix_x070(Module):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad():
self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))
@ -261,11 +268,15 @@ class RWKV_CMix_x070(Module):
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
def forward(self, x):
xx = self.time_shift(x) - x
k = x + xx * self.x_k
k = torch.relu(self.key(k)) ** 2
return self.value(k)
shift = torch.zeros_like(x)
shift[:, 1:, :] = x[:, :-1, :]
xx = shift - x # time_shift -> [1, seq_len, 768]
k = x + xx * self.x_k # -> [1, seq_len, 768]
k = torch.relu(self.key(k)) ** 2 # -> [1, seq_len, 768]
return self.value(k) # -> [1, seq_len, 768]
########################################################################################################
@ -289,11 +300,12 @@ class Block(Module):
def forward(self, x, v_first):
if self.layer_id == 0:
x = self.ln0(x)
x = self.ln0(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β
ln = self.ln1(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β
xx, v_first = self.att(self.ln1(x), v_first)
x = x + xx
x = x + self.ffn(self.ln2(x))
xx, v_first = self.att(ln, v_first) # [1, seq_len, 768] -> [1, seq_len, 768]
x = x + xx # [1, seq_len, 768] -> [1, seq_len, 768]
x = x + self.ffn(self.ln2(x)) # [1, seq_len, 768] -> [1, seq_len, 768]
return x, v_first
@ -317,14 +329,14 @@ class RWKV(nn.Module):
def forward(self, idx):
x = self.emb(idx)
x = self.emb(idx) # [1, seq_len] -> [1, seq_len, 768]
v_first = torch.empty_like(x)
v_first = torch.empty_like(x) # -> [1, seq_len, 768]
for block in self.blocks:
x, v_first = block(x, v_first)
x = self.ln_out(x)
x = self.head(x)
x = self.ln_out(x) # [1, seq_len, 768] -> [1, seq_len, 768]
x = self.head(x) # [1, seq_len, 768] -> [1, seq_len, 65536]
return x
@ -346,6 +358,19 @@ with torch.no_grad():
input = tokenizer.encode(prompt)
print(f"\nInput:\n{input}")
# 中国的首都是在
# 北 [probability 4.99%]
# 中 [probability 4.22%]
# 这 [probability 3.38%]
# 上 [probability 2.74%]
# 东 [probability 2.28%]
# 台 [probability 2.23%]
# 南 [probability 1.86%]
# 广 [probability 1.83%]
# 华 [probability 1.63%]
# 河 [probability 1.47%]
out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
print(f"\nOutput:\n{out}")

View File

@ -0,0 +1,114 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# RWKV-7 in numpy, by https://github.com/johanwind
import numpy as np
from torch import load as torch_load
layer_norm = lambda x, w, b: (x - x.mean()) / (x.var() + 1e-5) ** 0.5 * w + b
group_norm = (
lambda x, w, b: ((x - x.mean(axis=1, keepdims=1)) / (x.var(axis=1, keepdims=1) + 64e-5) ** 0.5).flatten() * w + b
)
sigmoid = lambda x: 1 / (1 + np.exp(-x))
def time_mixing(x, v0, last_x, S, params):
mr, mw, mk, mv, ma, mg, w_bias, r_k, Ww1, Ww2, Wa1, Wa2, a_bias, Wg1, Wg2 = params[:15]
k_k, k_a, Wr, Wk, Wv, Wo, ln_w, ln_b = params[-8:]
xr, xw, xk, xv, xa, xg = [x + m * (last_x - x) for m in [mr, mw, mk, mv, ma, mg]]
r = Wr @ xr
w = np.exp(-sigmoid(np.tanh(xw @ Ww1) @ Ww2 + w_bias) / np.e**0.5)
k = Wk @ xk
v = Wv @ xv
if v0 is None:
v0 = v
else:
Wv2, Wv1, v_bias = params[15:18]
v += (v0 - v) * sigmoid(xv @ Wv1 @ Wv2 + v_bias)
a = sigmoid(xa @ Wa1 @ Wa2 + a_bias)
g = sigmoid(xg @ Wg1) @ Wg2
kk = k * k_k
k += k * (a - 1) * k_a
r, w, k, v, kk, a, r_k = [i.reshape(N_HEAD, HEAD_SIZE, 1) for i in [r, w, k, v, kk, a, r_k]]
kk /= np.maximum(np.linalg.norm(kk, axis=1, keepdims=1), 1e-12)
S = S * w.mT - S @ kk * (kk * a).mT + v * k.mT
y = S @ r
y = group_norm(y, ln_w, ln_b)
y += ((r * k * r_k).sum(axis=1, keepdims=1) * v).flatten()
return Wo @ (y * g), v0, x, S
def channel_mixing(x, last_x, mix, Wk, Wv):
k = Wk @ (x + mix * (last_x - x))
v = Wv @ np.maximum(k, 0) ** 2
return v, x
def RWKV7(params, token, state):
x = params("emb")[0][token]
x = layer_norm(x, *params("blocks.0.ln0"))
v0 = None
for i in range(N_LAYER):
x_ = layer_norm(x, *params(f"blocks.{i}.ln1"))
dx, v0, state[0][i, 0], state[1][i] = time_mixing(
x_, v0, state[0][i, 0], state[1][i], params(f"blocks.{i}.att")
)
x = x + dx
x_ = layer_norm(x, *params(f"blocks.{i}.ln2"))
dx, state[0][i, 1] = channel_mixing(x_, state[0][i, 1], *params(f"blocks.{i}.ffn"))
x = x + dx
x = layer_norm(x, *params("ln_out"))
logits = params("head")[0] @ x
return logits, state
# Verification
# Available at https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth
MODEL_FILE = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
N_LAYER = 24
N_EMBD = 1024
HEAD_SIZE = 64
N_HEAD = N_EMBD // HEAD_SIZE
if 1: # Reference implementation
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
# pip install rwkv
import os
os.environ["RWKV_V7_ON"] = "1"
from rwkv.utils import PIPELINE
from rwkv.model import RWKV as referenceRWKV
model = referenceRWKV(model=MODEL_FILE[:-4], strategy="cpu fp32")
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokens = pipeline.encode(context)
reference_logits, state = model.forward(tokens, None)
reference_logits = reference_logits.numpy()
weights = torch_load(MODEL_FILE, map_location="cpu", weights_only=True)
weights = {k: v.squeeze().float().numpy() for k, v in weights.items()}
params = lambda prefix: [weights[key] for key in weights.keys() if key.startswith(prefix)]
state = (
np.zeros((N_LAYER, 2, N_EMBD), dtype=np.float32),
np.zeros((N_LAYER, N_HEAD, HEAD_SIZE, HEAD_SIZE), dtype=np.float32),
)
for token in tokens:
minimal_logits, state = RWKV7(params, token, state)
print("Deviation from official rwkv:", max(abs(minimal_logits - reference_logits)) / reference_logits.std())

View File

@ -1,5 +1,35 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn as nn
# 假设输入是一个 batch 的序列数据,形状为 (batch_size, seq_len, hidden_dim)
batch_size, seq_len, hidden_dim = 2, 2, 4
input_tensor = torch.randn(batch_size, seq_len, hidden_dim)
# 定义 LayerNorm 层
layer_norm1 = nn.LayerNorm(hidden_dim)
layer_norm2 = nn.LayerNorm(hidden_dim)
# 应用 LayerNorm
output = layer_norm1(input_tensor)
print(input_tensor.numpy())
print("\n")
print("\n")
print(output.detach().numpy())
output = layer_norm2(output)
print("\n")
print("\n")
print(output.detach().numpy())
x1 = torch.empty((1, 7, 768), dtype=float)
time_shift = nn.ZeroPad2d((1, 1, -1, 1))
xx = time_shift(x1)
x1 = torch.tensor([[1, 2]], dtype=float)
x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)