Format rwkv/RWKV-v7/rwkv_v7_demo.py

This commit is contained in:
Colin 2025-03-03 15:47:21 +08:00
parent 002f132818
commit 4f18296e40
1 changed file with 91 additions and 69 deletions

View File

@@ -5,7 +5,9 @@
import torch, types, os, gc, math, json import torch, types, os, gc, math, json
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from torch.nn import Module
from torch.nn import functional as F from torch.nn import functional as F
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
@@ -14,9 +16,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True # torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False) torch._C._jit_set_autocast_mode(False)
''' """
This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation) This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
''' """
args = types.SimpleNamespace() args = types.SimpleNamespace()
@@ -42,28 +44,26 @@ HEAD_SIZE = args.head_size_a
USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script
######################################################################################################## ########################################################################################################
# RWKV Tokenizer (slow version) # RWKV Tokenizer (slow version)
######################################################################################################## ########################################################################################################
class RWKV_TOKENIZER():
class RWKV_TOKENIZER:
table: list[list[list[bytes]]] table: list[list[list[bytes]]]
good: list[set[int]] good: list[set[int]]
wlen: list[int] wlen: list[int]
def __init__(self, file_name): def __init__(self, file_name):
self.idx2token = {} self.idx2token = {}
sorted = [] # must be already sorted sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines() lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines: for l in lines:
idx = int(l[:l.index(' ')]) idx = int(l[: l.index(" ")])
x = eval(l[l.index(' '):l.rindex(' ')]) x = eval(l[l.index(" ") : l.rindex(" ")])
x = x.encode("utf-8") if isinstance(x, str) else x x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes) assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):]) assert len(x) == int(l[l.rindex(" ") :])
sorted += [x] sorted += [x]
self.idx2token[idx] = x self.idx2token[idx] = x
@@ -107,25 +107,26 @@ class RWKV_TOKENIZER():
return tokens return tokens
def decodeBytes(self, tokens): def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens)) return b"".join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src: str): def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8")) return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens): def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8') return self.decodeBytes(tokens).decode("utf-8")
def printTokens(self, tokens): def printTokens(self, tokens):
for i in tokens: for i in tokens:
s = self.idx2token[i] s = self.idx2token[i]
try: try:
s = s.decode('utf-8') s = s.decode("utf-8")
except: except:
pass pass
print(f'{repr(s)}{i}', end=' ') print(f"{repr(s)}{i}", end=" ")
# print(repr(s), i) # print(repr(s), i)
print() print()
tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt") tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
######################################################################################################## ########################################################################################################
@@ -136,8 +137,21 @@ if USE_CUDA_KERNEL:
from torch.utils.cpp_extension import load from torch.utils.cpp_extension import load
load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False, load(
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) name="wkv7",
sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"],
is_python_module=False,
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
],
)
class WKV_7(torch.autograd.Function): class WKV_7(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, r, w, k, v, a, b): def forward(ctx, r, w, k, v, a, b):
@@ -202,11 +216,13 @@ else:
return out.view(B, T, C).to(dtype=DTYPE) return out.view(B, T, C).to(dtype=DTYPE)
######################################################################################################## ########################################################################################################
# RWKV TimeMix # RWKV TimeMix
######################################################################################################## ########################################################################################################
class RWKV_Tmix_x070(MyModule):
class RWKV_Tmix_x070(Module):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
@@ -253,7 +269,6 @@ class RWKV_Tmix_x070(MyModule):
self.output = nn.Linear(C, C, bias=False) self.output = nn.Linear(C, C, bias=False)
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!! self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
@MyFunction
def forward(self, x, v_first): def forward(self, x, v_first):
B, T, C = x.size() B, T, C = x.size()
H = self.n_head H = self.n_head
@@ -284,15 +299,19 @@ class RWKV_Tmix_x070(MyModule):
x = RWKV7_OP(r, w, k, v, -kk, kk * a) x = RWKV7_OP(r, w, k, v, -kk, kk * a)
x = self.ln_x(x.view(B * T, C)).view(B, T, C) x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C) x = x + (
(r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
).view(B, T, C)
x = self.output(x * g) x = self.output(x * g)
return x, v_first return x, v_first
######################################################################################################## ########################################################################################################
# RWKV ChannelMix # RWKV ChannelMix
######################################################################################################## ########################################################################################################
class RWKV_CMix_x070(MyModule):
class RWKV_CMix_x070(Module):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
@@ -305,7 +324,6 @@ class RWKV_CMix_x070(MyModule):
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x): def forward(self, x):
xx = self.time_shift(x) - x xx = self.time_shift(x) - x
@@ -313,11 +331,13 @@ class RWKV_CMix_x070(MyModule):
k = torch.relu(self.key(k)) ** 2 k = torch.relu(self.key(k)) ** 2
return self.value(k) return self.value(k)
######################################################################################################## ########################################################################################################
# RWKV Block # RWKV Block
######################################################################################################## ########################################################################################################
class Block(MyModule):
class Block(Module):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
@@ -330,7 +350,6 @@ class Block(MyModule):
self.att = RWKV_Tmix_x070(args, layer_id) self.att = RWKV_Tmix_x070(args, layer_id)
self.ffn = RWKV_CMix_x070(args, layer_id) self.ffn = RWKV_CMix_x070(args, layer_id)
@MyFunction
def forward(self, x, v_first): def forward(self, x, v_first):
if self.layer_id == 0: if self.layer_id == 0:
@@ -342,10 +361,12 @@ class Block(MyModule):
return x, v_first return x, v_first
######################################################################################################## ########################################################################################################
# RWKV Model # RWKV Model
######################################################################################################## ########################################################################################################
class RWKV(nn.Module): class RWKV(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
@@ -371,6 +392,7 @@ class RWKV(nn.Module):
return x return x
######################################################################################################## ########################################################################################################
# RWKV Inference # RWKV Inference
######################################################################################################## ########################################################################################################
@@ -386,32 +408,32 @@ with torch.no_grad():
prompt = "中国的首都是在" prompt = "中国的首都是在"
input = tokenizer.encode(prompt) input = tokenizer.encode(prompt)
print(f'\nInput:\n{input}') print(f"\nInput:\n{input}")
out = model.forward(torch.tensor(input).reshape(1, -1).cuda()) out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
print(f'\nOutput:\n{out}') print(f"\nOutput:\n{out}")
# logits of the last token => prediction for the next token # logits of the last token => prediction for the next token
out = out[0, -1] out = out[0, -1]
probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate) probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
print(f'\n{prompt}') print(f"\n{prompt}")
_, indices = torch.topk(probs, 10) # print top-10 possibilities _, indices = torch.topk(probs, 10) # print top-10 possibilities
for i in range(len(indices)): for i in range(len(indices)):
token_id = indices[i].item() token_id = indices[i].item()
token = tokenizer.decode([token_id]) token = tokenizer.decode([token_id])
token_prob = probs[token_id].item() token_prob = probs[token_id].item()
print(token, f'[probability {token_prob:.2%}]') print(token, f"[probability {token_prob:.2%}]")
######################################################################################################## ########################################################################################################
with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f: with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
todo = [json.loads(line) for line in f] todo = [json.loads(line) for line in f]
todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo] todo = [[doc["text"].rsplit(" ", 1)[0], " " + doc["text"].rsplit(" ", 1)[1]] for doc in todo]
print('\nCheck LAMBADA...') print("\nCheck LAMBADA...")
xsum = 0 xsum = 0
xcnt = 0 xcnt = 0
xacc = 0 xacc = 0
@@ -433,4 +455,4 @@ with torch.no_grad():
xsum += logits xsum += logits
xacc += 1 if correct else 0 xacc += 1 if correct else 0
if xcnt % 100 == 0 or xcnt == len(todo): if xcnt % 100 == 0 or xcnt == len(todo):
print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2)) print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))