From 821e7206b8acd22253545fe0be532542a8ed0f5e Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 5 Mar 2025 19:39:08 +0800 Subject: [PATCH] Refine rwkv. --- rwkv/RWKV-v7/model.md | 3 +- rwkv/RWKV-v7/rwkv_v7_demo.py | 139 ++++++++++++++++++++-------------- rwkv/RWKV-v7/rwkv_v7_numpy.py | 114 ++++++++++++++++++++++++++++ test/tensor.py | 30 ++++++++ 4 files changed, 227 insertions(+), 59 deletions(-) create mode 100644 rwkv/RWKV-v7/rwkv_v7_numpy.py diff --git a/rwkv/RWKV-v7/model.md b/rwkv/RWKV-v7/model.md index 8e52d89..f0b8525 100644 --- a/rwkv/RWKV-v7/model.md +++ b/rwkv/RWKV-v7/model.md @@ -14,5 +14,4 @@ K、V 就是等同于Transformer的Key与Value。 TimeMix,指的是过去信息x-1与当前信息x的混合。 xx = self.time_shift(x) - x 这个是典型的操作 - -nn.Embedding \ No newline at end of file +RWKV_Tmix_x070 \ No newline at end of file diff --git a/rwkv/RWKV-v7/rwkv_v7_demo.py b/rwkv/RWKV-v7/rwkv_v7_demo.py index f8f2c53..ea2524d 100644 --- a/rwkv/RWKV-v7/rwkv_v7_demo.py +++ b/rwkv/RWKV-v7/rwkv_v7_demo.py @@ -173,7 +173,6 @@ class RWKV_Tmix_x070(Module): self.k_a = nn.Parameter(torch.empty(1, 1, C)) self.r_k = nn.Parameter(torch.empty(H, N)) - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.receptance = nn.Linear(C, C, bias=False) self.key = nn.Linear(C, C, bias=False) self.value = nn.Linear(C, C, bias=False) @@ -181,64 +180,73 @@ class RWKV_Tmix_x070(Module): self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!! def forward(self, x, v_first): - B, T, C = x.size() - H = self.n_head - xx = self.time_shift(x) - x + B, T, C = x.size() # seq_len + H = self.n_head # 12 - xr = x + xx * self.x_r - xw = x + xx * self.x_w - xk = x + xx * self.x_k - xv = x + xx * self.x_v - xa = x + xx * self.x_a - xg = x + xx * self.x_g + shift = torch.zeros_like(x) + shift[:, 1:, :] = x[:, :-1, :] - r = self.receptance(xr) - w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5) - k = self.key(xk) - v = self.value(xv) + xx = shift - x # time_shift [1, seq_len, 768] -> [1, seq_len, 768] + + xr = x + xx * self.x_r # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + xw = x + xx * self.x_w # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + xk = x + xx * self.x_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + xv = x + xx * self.x_v # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + xa = x + xx * self.x_a # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + xg = x + xx * self.x_g # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + + r = self.receptance(xr) # [1, seq_len, 768] -> [1, seq_len, 768] + # [1, seq_len, 768] -> [1, seq_len, 768] + w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # softplus clamp to (-inf, -0.5) + k = self.key(xk) # [1, seq_len, 768] -> [1, seq_len, 768] + v = self.value(xv) # [1, seq_len, 768] -> [1, seq_len, 768] if self.layer_id == 0: v_first = v # store the v of the first layer else: + # -> [1, seq_len, 768] v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual - a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate" - g = torch.sigmoid(xg @ self.g1) @ self.g2 + a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # -> [1, seq_len, 768] # a is "in-context learning rate" + g = torch.sigmoid(xg @ self.g1) @ self.g2 # -> [1, seq_len, 768] - kk = k * self.k_k - kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) - k = k * (1 + (a - 1) * self.k_a) + kk = k * self.k_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] + kk = 
F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) # -> [1, seq_len, 768]
+ k = k * (1 + (a - 1) * self.k_a) # -> [1, seq_len, 768]

 def RWKV7_OP(r, w, k, v, a, b):
- B, T, C = r.size()
- H = C // HEAD_SIZE
- N = HEAD_SIZE
- r = r.view(B, T, H, N).float()
- k = k.view(B, T, H, N).float()
- v = v.view(B, T, H, N).float()
- a = a.view(B, T, H, N).float()
- b = b.view(B, T, H, N).float()
- w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
- out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
- state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
+ B, T, C = r.size() # B = 1, T = seq_len, C = 768
+ H = C // HEAD_SIZE # 12
+ N = HEAD_SIZE # 64
+ r = r.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
+ k = k.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
+ v = v.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
+ a = a.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
+ b = b.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
+ w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) # -> [1, seq_len, 12, 64]
+ out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
+ state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) # -> [1, 12, 64, 64]
 for t in range(T):
- kk = k[:, t, :].view(B, H, 1, N)
- rr = r[:, t, :].view(B, H, N, 1)
- vv = v[:, t, :].view(B, H, N, 1)
- aa = a[:, t, :].view(B, H, N, 1)
- bb = b[:, t, :].view(B, H, 1, N)
- state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
- out[:, t, :] = (state @ rr).view(B, H, N)
+ kk = k[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
+ rr = r[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
+ vv = v[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
+ aa = a[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
+ bb = b[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
+ state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk # -> [1, 12, 64, 64]
+ out[:, t, :] = (state @ rr).view(B, H, N) # -> [1, 12, 64] per step; out is [1, seq_len, 12, 64]

- return out.view(B, T, C).to(dtype=DTYPE)
+ return out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768]

- x = RWKV7_OP(r, w, k, v, -kk, kk * a)
+ x = RWKV7_OP(r, w, k, v, -kk, kk * a) # -> [1, seq_len, 768]

- x = self.ln_x(x.view(B * T, C)).view(B, T, C)
+ x = self.ln_x(x.view(B * T, C)).view(B, T, C) # -> [1, seq_len, 768]

- x = x + (
- (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
- ).view(B, T, C)
- x = self.output(x * g)
+ xx = r.view(B, T, H, -1) * k.view(B, T, H, -1) # -> [1, seq_len, 12, 64]
+ xx = xx * self.r_k # -> [1, seq_len, 12, 64]
+ xx = xx.sum(dim=-1, keepdim=True) # -> [1, seq_len, 12, 1]
+ xx = xx * v.view(B, T, H, -1) # [1, seq_len, 12, 1] x [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
+ xx = xx.view(B, T, C) # -> [1, seq_len, 768]
+ x = x + xx # -> [1, seq_len, 768]
+ x = self.output(x * g) # -> [1, seq_len, 768]

 return x, v_first

@@ -252,7 +260,6 @@ class RWKV_CMix_x070(Module):
 super().__init__()
 self.args = args
 self.layer_id = layer_id
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

 with torch.no_grad():
 self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))
@@ -261,11 +268,15 @@
 self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

 def forward(self, x):
- xx = self.time_shift(x) - x
- k = x + xx * self.x_k
- k = torch.relu(self.key(k)) ** 2
- return self.value(k)
+ shift = torch.zeros_like(x)
+ shift[:, 1:, :] = x[:, :-1, :]
+
+ xx = shift - x # time_shift -> [1, seq_len, 768]
+
+ k = x + xx * self.x_k # -> [1, seq_len, 768]
+
k = torch.relu(self.key(k)) ** 2 # -> [1, seq_len, 768] + return self.value(k) # -> [1, seq_len, 768] ######################################################################################################## @@ -289,11 +300,12 @@ class Block(Module): def forward(self, x, v_first): if self.layer_id == 0: - x = self.ln0(x) + x = self.ln0(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β + ln = self.ln1(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β - xx, v_first = self.att(self.ln1(x), v_first) - x = x + xx - x = x + self.ffn(self.ln2(x)) + xx, v_first = self.att(ln, v_first) # [1, seq_len, 768] -> [1, seq_len, 768] + x = x + xx # [1, seq_len, 768] -> [1, seq_len, 768] + x = x + self.ffn(self.ln2(x)) # [1, seq_len, 768] -> [1, seq_len, 768] return x, v_first @@ -317,14 +329,14 @@ class RWKV(nn.Module): def forward(self, idx): - x = self.emb(idx) + x = self.emb(idx) # [1, seq_len] -> [1, seq_len, 768] - v_first = torch.empty_like(x) + v_first = torch.empty_like(x) # -> [1, seq_len, 768] for block in self.blocks: x, v_first = block(x, v_first) - x = self.ln_out(x) - x = self.head(x) + x = self.ln_out(x) # [1, seq_len, 768] -> [1, seq_len, 768] + x = self.head(x) # [1, seq_len, 768] -> [1, seq_len, 65536] return x @@ -346,6 +358,19 @@ with torch.no_grad(): input = tokenizer.encode(prompt) print(f"\nInput:\n{input}") +# 中国的首都是在 +# 北 [probability 4.99%] +# 中 [probability 4.22%] +# 这 [probability 3.38%] +# 上 [probability 2.74%] +# 东 [probability 2.28%] +# 台 [probability 2.23%] +# 南 [probability 1.86%] +# 广 [probability 1.83%] +# 华 [probability 1.63%] +# 河 [probability 1.47%] + + out = model.forward(torch.tensor(input).reshape(1, -1).cuda()) print(f"\nOutput:\n{out}") diff --git a/rwkv/RWKV-v7/rwkv_v7_numpy.py b/rwkv/RWKV-v7/rwkv_v7_numpy.py new file mode 100644 index 0000000..4879e22 --- /dev/null +++ b/rwkv/RWKV-v7/rwkv_v7_numpy.py @@ -0,0 +1,114 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## +# RWKV-7 in numpy, by https://github.com/johanwind + +import numpy as np +from torch import load as torch_load + +layer_norm = lambda x, w, b: (x - x.mean()) / (x.var() + 1e-5) ** 0.5 * w + b +group_norm = ( + lambda x, w, b: ((x - x.mean(axis=1, keepdims=1)) / (x.var(axis=1, keepdims=1) + 64e-5) ** 0.5).flatten() * w + b +) +sigmoid = lambda x: 1 / (1 + np.exp(-x)) + + +def time_mixing(x, v0, last_x, S, params): + mr, mw, mk, mv, ma, mg, w_bias, r_k, Ww1, Ww2, Wa1, Wa2, a_bias, Wg1, Wg2 = params[:15] + k_k, k_a, Wr, Wk, Wv, Wo, ln_w, ln_b = params[-8:] + + xr, xw, xk, xv, xa, xg = [x + m * (last_x - x) for m in [mr, mw, mk, mv, ma, mg]] + + r = Wr @ xr + w = np.exp(-sigmoid(np.tanh(xw @ Ww1) @ Ww2 + w_bias) / np.e**0.5) + k = Wk @ xk + v = Wv @ xv + if v0 is None: + v0 = v + else: + Wv2, Wv1, v_bias = params[15:18] + v += (v0 - v) * sigmoid(xv @ Wv1 @ Wv2 + v_bias) + a = sigmoid(xa @ Wa1 @ Wa2 + a_bias) + g = sigmoid(xg @ Wg1) @ Wg2 + kk = k * k_k + k += k * (a - 1) * k_a + + r, w, k, v, kk, a, r_k = [i.reshape(N_HEAD, HEAD_SIZE, 1) for i in [r, w, k, v, kk, a, r_k]] + kk /= np.maximum(np.linalg.norm(kk, axis=1, keepdims=1), 1e-12) + + S = S * w.mT - S @ kk * (kk * a).mT + v * k.mT + y = S @ r + + y = group_norm(y, ln_w, ln_b) + y += ((r * k * r_k).sum(axis=1, keepdims=1) * v).flatten() + return Wo @ (y * g), v0, x, S + + +def channel_mixing(x, last_x, mix, Wk, Wv): + k = 
Wk @ (x + mix * (last_x - x)) + v = Wv @ np.maximum(k, 0) ** 2 + return v, x + + +def RWKV7(params, token, state): + x = params("emb")[0][token] + x = layer_norm(x, *params("blocks.0.ln0")) + + v0 = None + for i in range(N_LAYER): + x_ = layer_norm(x, *params(f"blocks.{i}.ln1")) + dx, v0, state[0][i, 0], state[1][i] = time_mixing( + x_, v0, state[0][i, 0], state[1][i], params(f"blocks.{i}.att") + ) + x = x + dx + + x_ = layer_norm(x, *params(f"blocks.{i}.ln2")) + dx, state[0][i, 1] = channel_mixing(x_, state[0][i, 1], *params(f"blocks.{i}.ffn")) + x = x + dx + + x = layer_norm(x, *params("ln_out")) + logits = params("head")[0] @ x + + return logits, state + + +# Verification + +# Available at https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth +MODEL_FILE = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth" + + +N_LAYER = 24 +N_EMBD = 1024 +HEAD_SIZE = 64 +N_HEAD = N_EMBD // HEAD_SIZE + +if 1: # Reference implementation + context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." + + # pip install rwkv + import os + + os.environ["RWKV_V7_ON"] = "1" + from rwkv.utils import PIPELINE + from rwkv.model import RWKV as referenceRWKV + + model = referenceRWKV(model=MODEL_FILE[:-4], strategy="cpu fp32") + pipeline = PIPELINE(model, "rwkv_vocab_v20230424") + tokens = pipeline.encode(context) + + reference_logits, state = model.forward(tokens, None) + reference_logits = reference_logits.numpy() + +weights = torch_load(MODEL_FILE, map_location="cpu", weights_only=True) +weights = {k: v.squeeze().float().numpy() for k, v in weights.items()} +params = lambda prefix: [weights[key] for key in weights.keys() if key.startswith(prefix)] + +state = ( + np.zeros((N_LAYER, 2, N_EMBD), dtype=np.float32), + np.zeros((N_LAYER, N_HEAD, HEAD_SIZE, HEAD_SIZE), dtype=np.float32), +) +for token in tokens: + minimal_logits, state = RWKV7(params, token, state) + +print("Deviation from official rwkv:", max(abs(minimal_logits - reference_logits)) / reference_logits.std()) diff --git a/test/tensor.py b/test/tensor.py index a239b34..570c0ff 100644 --- a/test/tensor.py +++ b/test/tensor.py @@ -1,5 +1,35 @@ import torch import torch.nn.functional as F +import torch.nn as nn + + +import torch +import torch.nn as nn + +# 假设输入是一个 batch 的序列数据,形状为 (batch_size, seq_len, hidden_dim) +batch_size, seq_len, hidden_dim = 2, 2, 4 +input_tensor = torch.randn(batch_size, seq_len, hidden_dim) + +# 定义 LayerNorm 层 +layer_norm1 = nn.LayerNorm(hidden_dim) +layer_norm2 = nn.LayerNorm(hidden_dim) + +# 应用 LayerNorm +output = layer_norm1(input_tensor) +print(input_tensor.numpy()) +print("\n") +print("\n") +print(output.detach().numpy()) + +output = layer_norm2(output) +print("\n") +print("\n") +print(output.detach().numpy()) + +x1 = torch.empty((1, 7, 768), dtype=float) +time_shift = nn.ZeroPad2d((1, 1, -1, 1)) +xx = time_shift(x1) + x1 = torch.tensor([[1, 2]], dtype=float) x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
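The core change in this commit is replacing the nn.ZeroPad2d((0, 0, 1, -1)) time_shift in RWKV_Tmix_x070 and RWKV_CMix_x070 with an explicit zeros_like + slice shift. The following minimal standalone sketch (not part of any file touched by the patch; the (1, 7, 768) shape is only illustrative) checks that the two formulations agree:

import torch
import torch.nn as nn

B, T, C = 1, 7, 768                    # illustrative batch / seq_len / n_embd
x = torch.randn(B, T, C)

# Removed formulation: pad one zero step at the front of the T axis, drop the last step.
time_shift = nn.ZeroPad2d((0, 0, 1, -1))
xx_old = time_shift(x) - x

# Formulation used by this commit: explicit shifted copy via zeros_like + slicing.
shift = torch.zeros_like(x)
shift[:, 1:, :] = x[:, :-1, :]
xx_new = shift - x

print(torch.allclose(xx_old, xx_new))  # expected: True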