########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# RWKV-7 in numpy, by https://github.com/johanwind

import numpy as np
from torch import load as torch_load
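# torch is used only to deserialize the .pth checkpoint; inference itself is pure
# numpy (ndarray.mT in time_mixing requires NumPy >= 2.0).
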
layer_norm = lambda x, w, b: (x - x.mean()) / (x.var() + 1e-5) ** 0.5 * w + b
group_norm = (
    lambda x, w, b: ((x - x.mean(axis=1, keepdims=1)) / (x.var(axis=1, keepdims=1) + 64e-5) ** 0.5).flatten() * w + b
)
sigmoid = lambda x: 1 / (1 + np.exp(-x))
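
# layer_norm and sigmoid act on flat vectors; group_norm normalizes each head
# independently (input shaped (N_HEAD, HEAD_SIZE, 1)) and flattens back to length N_EMBD.
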
def time_mixing(x, v0, last_x, S, params):
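    # One step of RWKV-7 time mixing ("attention") for a single token.
    #   x: current embedding; last_x: previous token's embedding (token-shift);
    #   v0: layer 0's value vector, reused as a value residual in deeper layers;
    #   S: per-head WKV state, shape (N_HEAD, HEAD_SIZE, HEAD_SIZE).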
    mr, mw, mk, mv, ma, mg, w_bias, r_k, Ww1, Ww2, Wa1, Wa2, a_bias, Wg1, Wg2 = params[:15]
    k_k, k_a, Wr, Wk, Wv, Wo, ln_w, ln_b = params[-8:]
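
    # Token-shift: each branch interpolates between the current and the previous token's embedding.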
    xr, xw, xk, xv, xa, xg = [x + m * (last_x - x) for m in [mr, mw, mk, mv, ma, mg]]

    r = Wr @ xr
    w = np.exp(-sigmoid(np.tanh(xw @ Ww1) @ Ww2 + w_bias) / np.e**0.5)
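    # w: per-channel decay factors in (exp(-1/sqrt(e)), 1) ~ (0.545, 1), from a low-rank MLP.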
    k = Wk @ xk
    v = Wv @ xv
    if v0 is None:
        v0 = v
    else:
        Wv2, Wv1, v_bias = params[15:18]
        v += (v0 - v) * sigmoid(xv @ Wv1 @ Wv2 + v_bias)
    a = sigmoid(xa @ Wa1 @ Wa2 + a_bias)
    g = sigmoid(xg @ Wg1) @ Wg2
    kk = k * k_k
    k += k * (a - 1) * k_a
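    # kk (normalized below) is the erase direction; a is the per-channel in-context
    # learning rate, which also rescales k via k_a.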

    r, w, k, v, kk, a, r_k = [i.reshape(N_HEAD, HEAD_SIZE, 1) for i in [r, w, k, v, kk, a, r_k]]
    kk /= np.maximum(np.linalg.norm(kk, axis=1, keepdims=1), 1e-12)
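
    # State update (a delta-rule-style step): decay each key channel by w, erase the
    # component along kk (scaled by the gate a), then write the outer product v @ k^T.
    # .mT transposes the last two axes of each per-head matrix stack.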
    S = S * w.mT - S @ kk * (kk * a).mT + v * k.mT
    y = S @ r

    y = group_norm(y, ln_w, ln_b)
    y += ((r * k * r_k).sum(axis=1, keepdims=1) * v).flatten()
    return Wo @ (y * g), v0, x, S


def channel_mixing(x, last_x, mix, Wk, Wv):
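    # Squared-ReLU feed-forward ("channel mixing") with token-shift on the input.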
    k = Wk @ (x + mix * (last_x - x))
    v = Wv @ np.maximum(k, 0) ** 2
    return v, x


def RWKV7(params, token, state):
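    # state[0][i, 0/1]: previous token's embedding for layer i's time/channel mixing.
    # state[1][i]: layer i's (N_HEAD, HEAD_SIZE, HEAD_SIZE) WKV state.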
    x = params("emb")[0][token]
    x = layer_norm(x, *params("blocks.0.ln0"))

    v0 = None
    for i in range(N_LAYER):
        x_ = layer_norm(x, *params(f"blocks.{i}.ln1"))
        dx, v0, state[0][i, 0], state[1][i] = time_mixing(
            x_, v0, state[0][i, 0], state[1][i], params(f"blocks.{i}.att")
        )
        x = x + dx

        x_ = layer_norm(x, *params(f"blocks.{i}.ln2"))
        dx, state[0][i, 1] = channel_mixing(x_, state[0][i, 1], *params(f"blocks.{i}.ffn"))
        x = x + dx

    x = layer_norm(x, *params("ln_out"))
    logits = params("head")[0] @ x

    return logits, state


# Verification

# Weights available at https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth
# Point MODEL_FILE at a local copy of the checkpoint.
MODEL_FILE = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
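
# Hyperparameters of the 0.4B checkpoint (24 layers, width 1024, 16 heads of size 64).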
N_LAYER = 24
N_EMBD = 1024
HEAD_SIZE = 64
N_HEAD = N_EMBD // HEAD_SIZE

if 1: # Reference implementation
    context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

    # pip install rwkv
    import os

    os.environ["RWKV_V7_ON"] = "1"
    from rwkv.utils import PIPELINE
    from rwkv.model import RWKV as referenceRWKV

    model = referenceRWKV(model=MODEL_FILE[:-4], strategy="cpu fp32")
    pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
    tokens = pipeline.encode(context)

    reference_logits, state = model.forward(tokens, None)
    reference_logits = reference_logits.numpy()

    weights = torch_load(MODEL_FILE, map_location="cpu", weights_only=True)
    weights = {k: v.squeeze().float().numpy() for k, v in weights.items()}
    params = lambda prefix: [weights[key] for key in weights.keys() if key.startswith(prefix)]
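    # (Relies on the insertion order of the checkpoint dict matching the unpacking
    # order in time_mixing.)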

    state = (
        np.zeros((N_LAYER, 2, N_EMBD), dtype=np.float32),
        np.zeros((N_LAYER, N_HEAD, HEAD_SIZE, HEAD_SIZE), dtype=np.float32),
    )
    for token in tokens:
        minimal_logits, state = RWKV7(params, token, state)

print("Deviation from official rwkv:", max(abs(minimal_logits - reference_logits)) / reference_logits.std())