Format rwkv/RWKV-v7/rwkv_v7_demo.py

Colin 2025-03-03 15:47:21 +08:00
parent 002f132818
commit 4f18296e40
1 changed file with 91 additions and 69 deletions

@@ -5,7 +5,9 @@
 import torch, types, os, gc, math, json
 import numpy as np
 import torch.nn as nn
+from torch.nn import Module
 from torch.nn import functional as F
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
@@ -14,9 +16,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
 torch._C._jit_set_autocast_mode(False)
-'''
+"""
 This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
-'''
+"""
 args = types.SimpleNamespace()
@@ -35,35 +37,33 @@ D_GATE_LORA = 128
 args.vocab_size = 65536
 # DTYPE = torch.bfloat16
 DTYPE = torch.half # better
 args.head_size_a = 64 # don't change
 HEAD_SIZE = args.head_size_a
 USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
-MyModule = torch.jit.ScriptModule
-MyFunction = torch.jit.script_method
-MyStatic = torch.jit.script
 ########################################################################################################
 # RWKV Tokenizer (slow version)
 ########################################################################################################
-class RWKV_TOKENIZER():
+class RWKV_TOKENIZER:
     table: list[list[list[bytes]]]
     good: list[set[int]]
     wlen: list[int]
     def __init__(self, file_name):
         self.idx2token = {}
         sorted = [] # must be already sorted
         lines = open(file_name, "r", encoding="utf-8").readlines()
         for l in lines:
-            idx = int(l[:l.index(' ')])
-            x = eval(l[l.index(' '):l.rindex(' ')])
+            idx = int(l[: l.index(" ")])
+            x = eval(l[l.index(" ") : l.rindex(" ")])
             x = x.encode("utf-8") if isinstance(x, str) else x
             assert isinstance(x, bytes)
-            assert len(x) == int(l[l.rindex(' '):])
+            assert len(x) == int(l[l.rindex(" ") :])
             sorted += [x]
             self.idx2token[idx] = x
@@ -76,7 +76,7 @@ class RWKV_TOKENIZER():
         self.good = [set() for i in range(256)]
         self.wlen = [0 for i in range(256)]
-        for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
+        for i in reversed(range(len(sorted))):  # reverse order - match longer tokens first
             s = sorted[i]
             if len(s) >= 2:
                 s0 = int(s[0])
@@ -107,25 +107,26 @@ class RWKV_TOKENIZER():
         return tokens
     def decodeBytes(self, tokens):
-        return b''.join(map(lambda i: self.idx2token[i], tokens))
+        return b"".join(map(lambda i: self.idx2token[i], tokens))
     def encode(self, src: str):
         return self.encodeBytes(src.encode("utf-8"))
     def decode(self, tokens):
-        return self.decodeBytes(tokens).decode('utf-8')
+        return self.decodeBytes(tokens).decode("utf-8")
     def printTokens(self, tokens):
         for i in tokens:
             s = self.idx2token[i]
             try:
-                s = s.decode('utf-8')
+                s = s.decode("utf-8")
             except:
                 pass
-            print(f'{repr(s)}{i}', end=' ')
+            print(f"{repr(s)}{i}", end=" ")
             # print(repr(s), i)
         print()
 tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
 ########################################################################################################
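For reference, a minimal usage sketch of the tokenizer instantiated above (it is not part of the commit and assumes rwkv_vocab_v20230424.txt sits next to the script):

# Sketch (not part of the commit): encode/decode round-trip with the tokenizer above.
ids = tokenizer.encode("中国的首都是在")          # list[int], greedy longest-match against the vocab
assert tokenizer.decode(ids) == "中国的首都是在"  # decode reverses encode for valid UTF-8 text
tokenizer.printTokens(ids)                        # prints each token next to its id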
@@ -136,8 +137,21 @@ if USE_CUDA_KERNEL:
     from torch.utils.cpp_extension import load
-    load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
-         verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
+    load(
+        name="wkv7",
+        sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"],
+        is_python_module=False,
+        verbose=True,
+        extra_cuda_cflags=[
+            "-res-usage",
+            "--use_fast_math",
+            "-O3",
+            "-Xptxas -O3",
+            "--extra-device-vectorization",
+            f"-D_N_={HEAD_SIZE}",
+        ],
+    )
     class WKV_7(torch.autograd.Function):
         @staticmethod
         def forward(ctx, r, w, k, v, a, b):
@@ -186,7 +200,7 @@ else:
             vv = v[:, t, :].view(B, H, N, 1)
             aa = a[:, t, :].view(B, H, N, 1)
             bb = b[:, t, :].view(B, H, 1, N)
-            state = state * w[: , t, :, None, :] + state @ aa @ bb + vv @ kk
+            state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
             out[:, t, :] = (state @ rr).view(B, H, N)
         # another method using einsum
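The explicit per-timestep loop in this hunk is the reference (non-CUDA) path of the WKV-7 operator: each head keeps an N x N state that is decayed per channel, updated by a rank-1 transition, written to with the value/key outer product, and read out with the receptance. A single-head restatement of that one line, as a sketch with random inputs and the demo's head size (not part of the commit):

# Sketch (not part of the commit): one step of the WKV-7 state update for a single head.
# Shapes mirror the views in the loop above.
import torch

N = 64                          # HEAD_SIZE in the demo
state = torch.zeros(N, N)
w_t = torch.rand(N)             # per-channel decay in (0, 1)
k_t = torch.randn(1, N)         # key (row vector)
v_t = torch.randn(N, 1)         # value (column vector)
a_t = torch.randn(N, 1)         # the caller passes -kk here (erase direction)
b_t = torch.randn(1, N)         # the caller passes kk * a (learning-rate-scaled write direction)
r_t = torch.randn(N, 1)         # receptance (column vector)

state = state * w_t[None, :] + state @ a_t @ b_t + v_t @ k_t  # decay, rank-1 transition, write
out_t = (state @ r_t).view(N)                                  # read the state with the receptance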
@@ -202,11 +216,13 @@ else:
         return out.view(B, T, C).to(dtype=DTYPE)
 ########################################################################################################
 # RWKV TimeMix
 ########################################################################################################
-class RWKV_Tmix_x070(MyModule):
+class RWKV_Tmix_x070(Module):
     def __init__(self, args, layer_id):
         super().__init__()
         self.args = args
@@ -220,40 +236,39 @@ class RWKV_Tmix_x070(MyModule):
         N = self.head_size
         C = args.n_embd
-        self.x_r = nn.Parameter(torch.empty(1,1,C))
-        self.x_w = nn.Parameter(torch.empty(1,1,C))
-        self.x_k = nn.Parameter(torch.empty(1,1,C))
-        self.x_v = nn.Parameter(torch.empty(1,1,C))
-        self.x_a = nn.Parameter(torch.empty(1,1,C))
-        self.x_g = nn.Parameter(torch.empty(1,1,C))
-        self.w0 = nn.Parameter(torch.empty(1,1,C))
+        self.x_r = nn.Parameter(torch.empty(1, 1, C))
+        self.x_w = nn.Parameter(torch.empty(1, 1, C))
+        self.x_k = nn.Parameter(torch.empty(1, 1, C))
+        self.x_v = nn.Parameter(torch.empty(1, 1, C))
+        self.x_a = nn.Parameter(torch.empty(1, 1, C))
+        self.x_g = nn.Parameter(torch.empty(1, 1, C))
+        self.w0 = nn.Parameter(torch.empty(1, 1, C))
         self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
         self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))
-        self.a0 = nn.Parameter(torch.empty(1,1,C))
+        self.a0 = nn.Parameter(torch.empty(1, 1, C))
         self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
         self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))
-        self.v0 = nn.Parameter(torch.empty(1,1,C))
+        self.v0 = nn.Parameter(torch.empty(1, 1, C))
         self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
         self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))
         self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
         self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))
-        self.k_k = nn.Parameter(torch.empty(1,1,C))
-        self.k_a = nn.Parameter(torch.empty(1,1,C))
-        self.r_k = nn.Parameter(torch.empty(H,N))
+        self.k_k = nn.Parameter(torch.empty(1, 1, C))
+        self.k_a = nn.Parameter(torch.empty(1, 1, C))
+        self.r_k = nn.Parameter(torch.empty(H, N))
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(C, C, bias=False)
         self.key = nn.Linear(C, C, bias=False)
         self.value = nn.Linear(C, C, bias=False)
         self.output = nn.Linear(C, C, bias=False)
         self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
-    @MyFunction
     def forward(self, x, v_first):
         B, T, C = x.size()
         H = self.n_head
@@ -267,32 +282,36 @@ class RWKV_Tmix_x070(MyModule):
         xg = x + xx * self.x_g
         r = self.receptance(xr)
         w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
         k = self.key(xk)
         v = self.value(xv)
         if self.layer_id == 0:
             v_first = v # store the v of the first layer
         else:
             v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
         a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
         g = torch.sigmoid(xg @ self.g1) @ self.g2
         kk = k * self.k_k
-        kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
-        k = k * (1 + (a-1) * self.k_a)
-        x = RWKV7_OP(r, w, k, v, -kk, kk*a)
+        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
+        k = k * (1 + (a - 1) * self.k_a)
+        x = RWKV7_OP(r, w, k, v, -kk, kk * a)
         x = self.ln_x(x.view(B * T, C)).view(B, T, C)
-        x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
+        x = x + (
+            (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
+        ).view(B, T, C)
         x = self.output(x * g)
         return x, v_first
 ########################################################################################################
 # RWKV ChannelMix
 ########################################################################################################
-class RWKV_CMix_x070(MyModule):
+class RWKV_CMix_x070(Module):
     def __init__(self, args, layer_id):
         super().__init__()
         self.args = args
@@ -305,32 +324,32 @@ class RWKV_CMix_x070(MyModule):
         self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
         self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
-    @MyFunction
     def forward(self, x):
         xx = self.time_shift(x) - x
         k = x + xx * self.x_k
         k = torch.relu(self.key(k)) ** 2
         return self.value(k)
 ########################################################################################################
 # RWKV Block
 ########################################################################################################
-class Block(MyModule):
+class Block(Module):
     def __init__(self, args, layer_id):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
         self.ln0 = nn.LayerNorm(args.n_embd) # only used in block 0, should be fused with emb
         self.ln1 = nn.LayerNorm(args.n_embd)
         self.ln2 = nn.LayerNorm(args.n_embd)
         self.att = RWKV_Tmix_x070(args, layer_id)
         self.ffn = RWKV_CMix_x070(args, layer_id)
-    @MyFunction
     def forward(self, x, v_first):
         if self.layer_id == 0:
@@ -342,10 +361,12 @@ class Block(MyModule):
         return x, v_first
 ########################################################################################################
 # RWKV Model
 ########################################################################################################
 class RWKV(nn.Module):
     def __init__(self, args):
         super().__init__()
@@ -371,6 +392,7 @@ class RWKV(nn.Module):
         return x
 ########################################################################################################
 # RWKV Inference
 ########################################################################################################
@@ -380,38 +402,38 @@ model_params = torch.load(MODEL_PATH, map_location="cpu")
 with torch.no_grad():
     model = RWKV(args).to(dtype=DTYPE).cuda()
     model.load_state_dict(model_params, strict=False) # we will ignore blocks.0.att.v0/v1/v2
     ########################################################################################################
     prompt = "中国的首都是在"
     input = tokenizer.encode(prompt)
-    print(f'\nInput:\n{input}')
-    out = model.forward(torch.tensor(input).reshape(1,-1).cuda())
-    print(f'\nOutput:\n{out}')
+    print(f"\nInput:\n{input}")
+    out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
+    print(f"\nOutput:\n{out}")
     # logits of the last token => prediction for the next token
     out = out[0, -1]
-    probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
-    print(f'\n{prompt}')
-    _, indices = torch.topk(probs, 10) # print top-10 possibilities
+    probs = F.softmax(out.float(), dim=-1)  # compute softmax in float (more accurate)
+    print(f"\n{prompt}")
+    _, indices = torch.topk(probs, 10)  # print top-10 possibilities
     for i in range(len(indices)):
         token_id = indices[i].item()
         token = tokenizer.decode([token_id])
         token_prob = probs[token_id].item()
-        print(token, f'[probability {token_prob:.2%}]')
+        print(token, f"[probability {token_prob:.2%}]")
     ########################################################################################################
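The block above only ranks candidates for the single next token. Because the demo runs in GPT-mode, a naive autoregressive continuation would re-feed the whole prefix at every step (which is exactly why the file's header notes that RNN-mode is faster for generation). A greedy sketch under that assumption, with an arbitrary 16-token budget and not part of the commit:

# Sketch (not part of the commit): naive greedy continuation in GPT-mode.
# Each step re-runs the full prefix through model.forward.
with torch.no_grad():
    ids = tokenizer.encode("中国的首都是在")
    for _ in range(16):  # hypothetical number of new tokens
        logits = model.forward(torch.tensor(ids).reshape(1, -1).cuda())[0, -1]
        ids.append(int(torch.argmax(logits.float()).item()))
    print(tokenizer.decode(ids))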
     with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
         todo = [json.loads(line) for line in f]
-        todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo]
-    print('\nCheck LAMBADA...')
+        todo = [[doc["text"].rsplit(" ", 1)[0], " " + doc["text"].rsplit(" ", 1)[1]] for doc in todo]
+    print("\nCheck LAMBADA...")
     xsum = 0
     xcnt = 0
     xacc = 0
@@ -421,9 +443,9 @@ with torch.no_grad():
         logits = 0
         correct = True
-        out = model.forward(torch.tensor(src+dst).reshape(1,-1).cuda())
+        out = model.forward(torch.tensor(src + dst).reshape(1, -1).cuda())
         for i in range(len(dst)):
-            ooo = out[0,len(src)-1+i].float()
+            ooo = out[0, len(src) - 1 + i].float()
             probs = F.softmax(ooo, dim=-1)
             logits += math.log(probs[dst[i]])
             if torch.argmax(probs).item() != dst[i]:
@@ -433,4 +455,4 @@ with torch.no_grad():
         xsum += logits
         xacc += 1 if correct else 0
         if xcnt % 100 == 0 or xcnt == len(todo):
-            print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2))
+            print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))
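In this evaluation loop, "ppl" is the exponential of the negative mean of each example's summed target-token log-probabilities, and "acc" is the percentage of examples whose every target token was the greedy argmax. A minimal restatement of the two aggregates, using made-up per-example values purely for illustration (not part of the commit):

# Sketch (not part of the commit): the two numbers printed above, given hypothetical
# per-example results as (sum_of_target_token_logprobs, all_tokens_greedy_correct).
import math

results = [(-3.2, True), (-7.9, False), (-1.1, True)]  # made-up values for illustration
ppl = math.exp(-sum(lp for lp, _ in results) / len(results))  # exp of negative mean log-prob per example
acc = 100.0 * sum(ok for _, ok in results) / len(results)     # percent of fully correct completions
print(round(ppl, 2), round(acc, 2))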