Refine rwkv.
This commit is contained in:
parent
240858c030
commit
821e7206b8
|
@ -14,5 +14,4 @@ K、V 就是等同于Transformer的Key与Value。
|
|||
TimeMix,指的是过去信息x-1与当前信息x的混合。 xx = self.time_shift(x) - x 这个是典型的操作
|
||||
|
||||
|
||||
|
||||
nn.Embedding
|
||||
RWKV_Tmix_x070
|
|
@ -173,7 +173,6 @@ class RWKV_Tmix_x070(Module):
|
|||
self.k_a = nn.Parameter(torch.empty(1, 1, C))
|
||||
self.r_k = nn.Parameter(torch.empty(H, N))
|
||||
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
self.receptance = nn.Linear(C, C, bias=False)
|
||||
self.key = nn.Linear(C, C, bias=False)
|
||||
self.value = nn.Linear(C, C, bias=False)
|
||||
|
@ -181,64 +180,73 @@ class RWKV_Tmix_x070(Module):
|
|||
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
|
||||
|
||||
def forward(self, x, v_first):
|
||||
B, T, C = x.size()
|
||||
H = self.n_head
|
||||
xx = self.time_shift(x) - x
|
||||
B, T, C = x.size() # seq_len
|
||||
H = self.n_head # 12
|
||||
|
||||
xr = x + xx * self.x_r
|
||||
xw = x + xx * self.x_w
|
||||
xk = x + xx * self.x_k
|
||||
xv = x + xx * self.x_v
|
||||
xa = x + xx * self.x_a
|
||||
xg = x + xx * self.x_g
|
||||
shift = torch.zeros_like(x)
|
||||
shift[:, 1:, :] = x[:, :-1, :]
|
||||
|
||||
r = self.receptance(xr)
|
||||
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
xx = shift - x # time_shift [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
|
||||
xr = x + xx * self.x_r # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
xw = x + xx * self.x_w # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
xk = x + xx * self.x_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
xv = x + xx * self.x_v # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
xa = x + xx * self.x_a # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
xg = x + xx * self.x_g # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
|
||||
r = self.receptance(xr) # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
# [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # softplus clamp to (-inf, -0.5)
|
||||
k = self.key(xk) # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
v = self.value(xv) # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
if self.layer_id == 0:
|
||||
v_first = v # store the v of the first layer
|
||||
else:
|
||||
# -> [1, seq_len, 768]
|
||||
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
|
||||
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
|
||||
g = torch.sigmoid(xg @ self.g1) @ self.g2
|
||||
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # -> [1, seq_len, 768] # a is "in-context learning rate"
|
||||
g = torch.sigmoid(xg @ self.g1) @ self.g2 # -> [1, seq_len, 768]
|
||||
|
||||
kk = k * self.k_k
|
||||
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
|
||||
k = k * (1 + (a - 1) * self.k_a)
|
||||
kk = k * self.k_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
|
||||
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) # -> [1, seq_len, 768]
|
||||
k = k * (1 + (a - 1) * self.k_a) # -> [1, seq_len, 768]
|
||||
|
||||
def RWKV7_OP(r, w, k, v, a, b):
|
||||
B, T, C = r.size()
|
||||
H = C // HEAD_SIZE
|
||||
N = HEAD_SIZE
|
||||
r = r.view(B, T, H, N).float()
|
||||
k = k.view(B, T, H, N).float()
|
||||
v = v.view(B, T, H, N).float()
|
||||
a = a.view(B, T, H, N).float()
|
||||
b = b.view(B, T, H, N).float()
|
||||
w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
|
||||
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
|
||||
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
|
||||
B, T, C = r.size() # 768
|
||||
H = C // HEAD_SIZE # 12
|
||||
N = HEAD_SIZE # 64
|
||||
r = r.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
|
||||
k = k.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
|
||||
v = v.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
|
||||
a = a.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
|
||||
b = b.view(B, T, H, N).float() # -> [1, seq_len, 12, 64]
|
||||
w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) # -> [1, seq_len, 12, 64]
|
||||
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
|
||||
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64]
|
||||
|
||||
for t in range(T):
|
||||
kk = k[:, t, :].view(B, H, 1, N)
|
||||
rr = r[:, t, :].view(B, H, N, 1)
|
||||
vv = v[:, t, :].view(B, H, N, 1)
|
||||
aa = a[:, t, :].view(B, H, N, 1)
|
||||
bb = b[:, t, :].view(B, H, 1, N)
|
||||
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
|
||||
out[:, t, :] = (state @ rr).view(B, H, N)
|
||||
kk = k[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
|
||||
rr = r[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
|
||||
vv = v[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
|
||||
aa = a[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1]
|
||||
bb = b[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64]
|
||||
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk # -> [1, 12, 64, 64]
|
||||
out[:, t, :] = (state @ rr).view(B, H, N) # -> [1, seq_len, 12, 64]
|
||||
|
||||
return out.view(B, T, C).to(dtype=DTYPE)
|
||||
return out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768]
|
||||
|
||||
x = RWKV7_OP(r, w, k, v, -kk, kk * a)
|
||||
x = RWKV7_OP(r, w, k, v, -kk, kk * a) # -> [1, seq_len, 768]
|
||||
|
||||
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
|
||||
x = self.ln_x(x.view(B * T, C)).view(B, T, C) # -> [1, seq_len, 768]
|
||||
|
||||
x = x + (
|
||||
(r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
|
||||
).view(B, T, C)
|
||||
x = self.output(x * g)
|
||||
xx = r.view(B, T, H, -1) * k.view(B, T, H, -1) # -> [1, seq_len, 12, 64]
|
||||
xx = xx * self.r_k # -> [1, seq_len, 12, 64]
|
||||
xx = xx.sum(dim=-1, keepdim=True) # -> [1, seq_len, 12, 1]
|
||||
xx = xx * v.view(B, T, H, -1) # [1, seq_len, 12, 1] x [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
|
||||
xx = xx.view(B, T, C) # -> [1, seq_len, 768]
|
||||
x = x + xx # -> [1, seq_len, 768]
|
||||
x = self.output(x * g) # -> [1, seq_len, 768]
|
||||
return x, v_first
|
||||
|
||||
|
||||
|
@ -252,7 +260,6 @@ class RWKV_CMix_x070(Module):
|
|||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad():
|
||||
self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))
|
||||
|
@ -261,11 +268,15 @@ class RWKV_CMix_x070(Module):
|
|||
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x) - x
|
||||
|
||||
k = x + xx * self.x_k
|
||||
k = torch.relu(self.key(k)) ** 2
|
||||
return self.value(k)
|
||||
shift = torch.zeros_like(x)
|
||||
shift[:, 1:, :] = x[:, :-1, :]
|
||||
|
||||
xx = shift - x # time_shift -> [1, seq_len, 768]
|
||||
|
||||
k = x + xx * self.x_k # -> [1, seq_len, 768]
|
||||
k = torch.relu(self.key(k)) ** 2 # -> [1, seq_len, 768]
|
||||
return self.value(k) # -> [1, seq_len, 768]
|
||||
|
||||
|
||||
########################################################################################################
|
||||
|
@ -289,11 +300,12 @@ class Block(Module):
|
|||
def forward(self, x, v_first):
|
||||
|
||||
if self.layer_id == 0:
|
||||
x = self.ln0(x)
|
||||
x = self.ln0(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β
|
||||
ln = self.ln1(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β
|
||||
|
||||
xx, v_first = self.att(self.ln1(x), v_first)
|
||||
x = x + xx
|
||||
x = x + self.ffn(self.ln2(x))
|
||||
xx, v_first = self.att(ln, v_first) # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
x = x + xx # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
x = x + self.ffn(self.ln2(x)) # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
|
||||
return x, v_first
|
||||
|
||||
|
@ -317,14 +329,14 @@ class RWKV(nn.Module):
|
|||
|
||||
def forward(self, idx):
|
||||
|
||||
x = self.emb(idx)
|
||||
x = self.emb(idx) # [1, seq_len] -> [1, seq_len, 768]
|
||||
|
||||
v_first = torch.empty_like(x)
|
||||
v_first = torch.empty_like(x) # -> [1, seq_len, 768]
|
||||
for block in self.blocks:
|
||||
x, v_first = block(x, v_first)
|
||||
|
||||
x = self.ln_out(x)
|
||||
x = self.head(x)
|
||||
x = self.ln_out(x) # [1, seq_len, 768] -> [1, seq_len, 768]
|
||||
x = self.head(x) # [1, seq_len, 768] -> [1, seq_len, 65536]
|
||||
|
||||
return x
|
||||
|
||||
|
@ -346,6 +358,19 @@ with torch.no_grad():
|
|||
input = tokenizer.encode(prompt)
|
||||
print(f"\nInput:\n{input}")
|
||||
|
||||
# 中国的首都是在
|
||||
# 北 [probability 4.99%]
|
||||
# 中 [probability 4.22%]
|
||||
# 这 [probability 3.38%]
|
||||
# 上 [probability 2.74%]
|
||||
# 东 [probability 2.28%]
|
||||
# 台 [probability 2.23%]
|
||||
# 南 [probability 1.86%]
|
||||
# 广 [probability 1.83%]
|
||||
# 华 [probability 1.63%]
|
||||
# 河 [probability 1.47%]
|
||||
|
||||
|
||||
out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
|
||||
print(f"\nOutput:\n{out}")
|
||||
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
# RWKV-7 in numpy, by https://github.com/johanwind
|
||||
|
||||
import numpy as np
|
||||
from torch import load as torch_load
|
||||
|
||||
layer_norm = lambda x, w, b: (x - x.mean()) / (x.var() + 1e-5) ** 0.5 * w + b
|
||||
group_norm = (
|
||||
lambda x, w, b: ((x - x.mean(axis=1, keepdims=1)) / (x.var(axis=1, keepdims=1) + 64e-5) ** 0.5).flatten() * w + b
|
||||
)
|
||||
sigmoid = lambda x: 1 / (1 + np.exp(-x))
|
||||
|
||||
|
||||
def time_mixing(x, v0, last_x, S, params):
|
||||
mr, mw, mk, mv, ma, mg, w_bias, r_k, Ww1, Ww2, Wa1, Wa2, a_bias, Wg1, Wg2 = params[:15]
|
||||
k_k, k_a, Wr, Wk, Wv, Wo, ln_w, ln_b = params[-8:]
|
||||
|
||||
xr, xw, xk, xv, xa, xg = [x + m * (last_x - x) for m in [mr, mw, mk, mv, ma, mg]]
|
||||
|
||||
r = Wr @ xr
|
||||
w = np.exp(-sigmoid(np.tanh(xw @ Ww1) @ Ww2 + w_bias) / np.e**0.5)
|
||||
k = Wk @ xk
|
||||
v = Wv @ xv
|
||||
if v0 is None:
|
||||
v0 = v
|
||||
else:
|
||||
Wv2, Wv1, v_bias = params[15:18]
|
||||
v += (v0 - v) * sigmoid(xv @ Wv1 @ Wv2 + v_bias)
|
||||
a = sigmoid(xa @ Wa1 @ Wa2 + a_bias)
|
||||
g = sigmoid(xg @ Wg1) @ Wg2
|
||||
kk = k * k_k
|
||||
k += k * (a - 1) * k_a
|
||||
|
||||
r, w, k, v, kk, a, r_k = [i.reshape(N_HEAD, HEAD_SIZE, 1) for i in [r, w, k, v, kk, a, r_k]]
|
||||
kk /= np.maximum(np.linalg.norm(kk, axis=1, keepdims=1), 1e-12)
|
||||
|
||||
S = S * w.mT - S @ kk * (kk * a).mT + v * k.mT
|
||||
y = S @ r
|
||||
|
||||
y = group_norm(y, ln_w, ln_b)
|
||||
y += ((r * k * r_k).sum(axis=1, keepdims=1) * v).flatten()
|
||||
return Wo @ (y * g), v0, x, S
|
||||
|
||||
|
||||
def channel_mixing(x, last_x, mix, Wk, Wv):
|
||||
k = Wk @ (x + mix * (last_x - x))
|
||||
v = Wv @ np.maximum(k, 0) ** 2
|
||||
return v, x
|
||||
|
||||
|
||||
def RWKV7(params, token, state):
|
||||
x = params("emb")[0][token]
|
||||
x = layer_norm(x, *params("blocks.0.ln0"))
|
||||
|
||||
v0 = None
|
||||
for i in range(N_LAYER):
|
||||
x_ = layer_norm(x, *params(f"blocks.{i}.ln1"))
|
||||
dx, v0, state[0][i, 0], state[1][i] = time_mixing(
|
||||
x_, v0, state[0][i, 0], state[1][i], params(f"blocks.{i}.att")
|
||||
)
|
||||
x = x + dx
|
||||
|
||||
x_ = layer_norm(x, *params(f"blocks.{i}.ln2"))
|
||||
dx, state[0][i, 1] = channel_mixing(x_, state[0][i, 1], *params(f"blocks.{i}.ffn"))
|
||||
x = x + dx
|
||||
|
||||
x = layer_norm(x, *params("ln_out"))
|
||||
logits = params("head")[0] @ x
|
||||
|
||||
return logits, state
|
||||
|
||||
|
||||
# Verification
|
||||
|
||||
# Available at https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth
|
||||
MODEL_FILE = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
|
||||
|
||||
|
||||
N_LAYER = 24
|
||||
N_EMBD = 1024
|
||||
HEAD_SIZE = 64
|
||||
N_HEAD = N_EMBD // HEAD_SIZE
|
||||
|
||||
if 1: # Reference implementation
|
||||
context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
|
||||
|
||||
# pip install rwkv
|
||||
import os
|
||||
|
||||
os.environ["RWKV_V7_ON"] = "1"
|
||||
from rwkv.utils import PIPELINE
|
||||
from rwkv.model import RWKV as referenceRWKV
|
||||
|
||||
model = referenceRWKV(model=MODEL_FILE[:-4], strategy="cpu fp32")
|
||||
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
||||
tokens = pipeline.encode(context)
|
||||
|
||||
reference_logits, state = model.forward(tokens, None)
|
||||
reference_logits = reference_logits.numpy()
|
||||
|
||||
weights = torch_load(MODEL_FILE, map_location="cpu", weights_only=True)
|
||||
weights = {k: v.squeeze().float().numpy() for k, v in weights.items()}
|
||||
params = lambda prefix: [weights[key] for key in weights.keys() if key.startswith(prefix)]
|
||||
|
||||
state = (
|
||||
np.zeros((N_LAYER, 2, N_EMBD), dtype=np.float32),
|
||||
np.zeros((N_LAYER, N_HEAD, HEAD_SIZE, HEAD_SIZE), dtype=np.float32),
|
||||
)
|
||||
for token in tokens:
|
||||
minimal_logits, state = RWKV7(params, token, state)
|
||||
|
||||
print("Deviation from official rwkv:", max(abs(minimal_logits - reference_logits)) / reference_logits.std())
|
|
@ -1,5 +1,35 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# 假设输入是一个 batch 的序列数据,形状为 (batch_size, seq_len, hidden_dim)
|
||||
batch_size, seq_len, hidden_dim = 2, 2, 4
|
||||
input_tensor = torch.randn(batch_size, seq_len, hidden_dim)
|
||||
|
||||
# 定义 LayerNorm 层
|
||||
layer_norm1 = nn.LayerNorm(hidden_dim)
|
||||
layer_norm2 = nn.LayerNorm(hidden_dim)
|
||||
|
||||
# 应用 LayerNorm
|
||||
output = layer_norm1(input_tensor)
|
||||
print(input_tensor.numpy())
|
||||
print("\n")
|
||||
print("\n")
|
||||
print(output.detach().numpy())
|
||||
|
||||
output = layer_norm2(output)
|
||||
print("\n")
|
||||
print("\n")
|
||||
print(output.detach().numpy())
|
||||
|
||||
x1 = torch.empty((1, 7, 768), dtype=float)
|
||||
time_shift = nn.ZeroPad2d((1, 1, -1, 1))
|
||||
xx = time_shift(x1)
|
||||
|
||||
|
||||
x1 = torch.tensor([[1, 2]], dtype=float)
|
||||
x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
|
||||
|
|
Loading…
Reference in New Issue