########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import torch, types, os, gc, math, json
import numpy as np
import torch.nn as nn
from torch.nn import Module
from torch.nn import functional as F
np.set_printoptions(precision=4, suppress=True, linewidth=200)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)
"""
This will load RWKV-7 "Goose" x070 and run inference in GPT-mode (slower than RNN-mode for autoregressive generation)
"""
args = types.SimpleNamespace()
# model download: https://huggingface.co/BlinkDL/rwkv-7-world
MODEL_PATH = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
# for 0.1B
args.n_layer = 12
args.n_embd = 768
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
D_GATE_LORA = 128
args.vocab_size = 65536
# DTYPE = torch.bfloat16
DTYPE = torch.half # better
args.head_size_a = 64 # don't change
HS = args.head_size_a
########################################################################################################
# RWKV Tokenizer (slow version)
########################################################################################################
class RWKV_TOKENIZER:
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]

    def __init__(self, file_name):
        self.idx2token = {}
        sorted = []  # must be already sorted
        lines = open(file_name, "r", encoding="utf-8").readlines()
        for l in lines:
            idx = int(l[: l.index(" ")])
            x = eval(l[l.index(" ") : l.rindex(" ")])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(" ") :])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for i in reversed(range(len(sorted))):  # reverse order - match longer tokens first
            s = sorted[i]
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1] += [s]
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)
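
    # encodeBytes below is greedy longest-match tokenization: at position i it starts from the single
    # byte src[i]; if some multi-byte token begins with the byte pair (src[i], src[i+1]) (checked via
    # self.good / self.table), it takes the longest such match (the table was filled in reverse vocab
    # order, so longer candidates are tried first).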
    def encodeBytes(self, src: bytes) -> list[int]:
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]
            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    try:
                        s = next(filter(sss.startswith, self.table[s0][s1]))
                    except:
                        pass
            tokens.append(self.token2idx[s])
            i += len(s)
        return tokens

    def decodeBytes(self, tokens):
        return b"".join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        return self.decodeBytes(tokens).decode("utf-8")

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode("utf-8")
            except:
                pass
            print(f"{repr(s)}{i}", end=" ")
            # print(repr(s), i)
        print()
tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
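
# A minimal round-trip sanity check (illustrative; the sample string is arbitrary). It relies on the
# vocab containing every single byte as a token, which rwkv_vocab_v20230424 does.
_sample = "Hello, 世界"
assert tokenizer.decode(tokenizer.encode(_sample)) == _sample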
########################################################################################################
# RWKV TimeMix
########################################################################################################
class RWKV_Tmix_x070(Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0
        H = self.n_head
        HS = self.head_size
        C = args.n_embd

        self.x_r = nn.Parameter(torch.empty(1, 1, C))
        self.x_w = nn.Parameter(torch.empty(1, 1, C))
        self.x_k = nn.Parameter(torch.empty(1, 1, C))
        self.x_v = nn.Parameter(torch.empty(1, 1, C))
        self.x_a = nn.Parameter(torch.empty(1, 1, C))
        self.x_g = nn.Parameter(torch.empty(1, 1, C))

        self.w0 = nn.Parameter(torch.empty(1, 1, C))
        self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
        self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))

        self.a0 = nn.Parameter(torch.empty(1, 1, C))
        self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
        self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))

        self.v0 = nn.Parameter(torch.empty(1, 1, C))
        self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
        self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))

        self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
        self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))

        self.k_k = nn.Parameter(torch.empty(1, 1, C))
        self.k_a = nn.Parameter(torch.empty(1, 1, C))
        self.r_k = nn.Parameter(torch.empty(H, HS))

        self.receptance = nn.Linear(C, C, bias=False)
        self.key = nn.Linear(C, C, bias=False)
        self.value = nn.Linear(C, C, bias=False)
        self.output = nn.Linear(C, C, bias=False)
        self.ln_x = nn.GroupNorm(H, C, eps=64e-5)  # !!! notice eps value !!!
    def forward(self, x, v_first):
        B, T, C = x.size()  # B=1, T=seq_len, C=768
        H = self.n_head  # 12

        xx = torch.zeros_like(x)  # time_shift [1, seq_len, 768] -> [1, seq_len, 768]
        xx[:, 0, :] = -x[:, 0, :]
        xx[:, 1:, :] = x[:, :-1, :] - x[:, 1:, :]

        xr = x + xx * self.x_r  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xw = x + xx * self.x_w  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xk = x + xx * self.x_k  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xv = x + xx * self.x_v  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xa = x + xx * self.x_a  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xg = x + xx * self.x_g  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]

        r = self.receptance(xr)  # Linear [1, seq_len, 768] -> [1, seq_len, 768]

        xw = torch.tanh(xw @ self.w1)  # -> [1, seq_len, 64]
        xw = xw @ self.w2 + self.w0  # -> [1, seq_len, 768]
        xw = -F.softplus(-xw)  # softplus has range (0, inf), so -softplus(-xw) is always <= 0
        w = xw - 0.5  # so w <= -0.5  -> [1, seq_len, 768]

        k = self.key(xk)  # Linear [1, seq_len, 768] -> [1, seq_len, 768]
        v = self.value(xv)  # Linear [1, seq_len, 768] -> [1, seq_len, 768]
        if self.layer_id == 0:
            v_first = v  # store the v of the first layer
        else:
            xv = (xv @ self.v1) @ self.v2
            xv = xv + self.v0
            xv = torch.sigmoid(xv)
            v = v + (v_first - v) * xv  # add value residual -> [1, seq_len, 768]

        xa = (xa @ self.a1) @ self.a2
        xa = xa + self.a0
        a = torch.sigmoid(xa)  # -> [1, seq_len, 768]

        xg = xg @ self.g1
        xg = torch.sigmoid(xg)
        g = xg @ self.g2  # -> [1, seq_len, 768]

        kk = k * self.k_k  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)  # -> [1, seq_len, 768]
        k = k * (1 + (a - 1) * self.k_a)  # -> [1, seq_len, 768]

        # start op
        a_op = -kk
        b_op = kk * a

        B, T, C = r.size()  # C = 768
        H = C // HS  # 12

        r_op = r.view(B, T, H, HS, 1).float()  # -> [1, seq_len, 12, 64, 1]
        k_op = k.view(B, T, H, 1, HS).float()  # -> [1, seq_len, 12, 1, 64]
        v_op = v.view(B, T, H, HS, 1).float()  # -> [1, seq_len, 12, 64, 1]
        a_op = a_op.view(B, T, H, HS, 1).float()  # -> [1, seq_len, 12, 64, 1]
        b_op = b_op.view(B, T, H, 1, HS).float()  # -> [1, seq_len, 12, 1, 64]

        w_op = w.view(B, T, H, HS).float()
        w_op = torch.exp(-torch.exp(w_op))  # per-channel decay in (0, 1)  -> [1, seq_len, 12, 64]
        w_op = w_op.view(B, T, H, 1, HS)  # -> [1, seq_len, 12, 1, 64]

        out = torch.zeros((B, T, H, HS), device=r_op.device, dtype=torch.float)  # -> [1, seq_len, 12, 64]
        state = torch.zeros((B, H, HS, HS), device=r_op.device, dtype=torch.float)  # -> [1, 12, 64, 64]
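
        # A sketch of what the loop below computes per head (S is the 64x64 state; rr_op, kk_op, vv_op,
        # aa_op, bb_op are the per-step slices; ww_op = exp(-exp(w)) is the per-channel decay in (0, 1)):
        #   S_t = S_{t-1} @ diag(ww_t)  +  S_{t-1} @ (aa_t @ bb_t)  +  vv_t @ kk_t
        #   y_t = S_t @ rr_t
        # where aa = -normalize(k * k_k) (a column) and bb = normalize(k * k_k) * a (a row), so the middle
        # term erases / rewrites old state along the normalized key direction, gated by the learned "a".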
        for t in range(T):
            rr_op = r_op[:, t, :]  # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
            kk_op = k_op[:, t, :]  # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
            vv_op = v_op[:, t, :]  # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
            aa_op = a_op[:, t, :]  # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
            bb_op = b_op[:, t, :]  # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
            ww_op = w_op[:, t, :]  # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
            state = state * ww_op + state @ aa_op @ bb_op + vv_op @ kk_op  # -> [1, 12, 64, 64]
            out[:, t, :] = (state @ rr_op).view(B, H, HS)  # -> [1, 12, 64]
        x = out.view(B, T, C).to(dtype=DTYPE)  # -> [1, seq_len, 768]
        # end op

        x = self.ln_x(x.view(B * T, C)).view(B, T, C)  # -> [1, seq_len, 768]

        xx = r.view(B, T, H, -1) * k.view(B, T, H, -1)  # -> [1, seq_len, 12, 64]
        xx = xx * self.r_k  # -> [1, seq_len, 12, 64]
        xx = xx.sum(dim=-1, keepdim=True)  # -> [1, seq_len, 12, 1]
        xx = xx * v.view(B, T, H, -1)  # [1, seq_len, 12, 1] * [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
        xx = xx.view(B, T, C)  # -> [1, seq_len, 768]
        x = x + xx  # -> [1, seq_len, 768]

        x = self.output(x * g)  # Linear -> [1, seq_len, 768]
        return x, v_first
########################################################################################################
# RWKV ChannelMix
########################################################################################################
class RWKV_CMix_x070(Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        with torch.no_grad():
            self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
    def forward(self, x):
        shift = torch.zeros_like(x)
        shift[:, 1:, :] = x[:, :-1, :]
        xx = shift - x  # time_shift -> [1, seq_len, 768]

        k = x + xx * self.x_k  # -> [1, seq_len, 768]
        k = torch.relu(self.key(k)) ** 2  # Linear -> [1, seq_len, 3072]
        return self.value(k)  # Linear -> [1, seq_len, 768]
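
# In equation form, the channel-mix above is k_t = x_t + (x_{t-1} - x_t) * x_k followed by
# y_t = value(relu(key(k_t)) ** 2): a squared-ReLU FFN with token-shift mixing, dim_ffn = 4 * n_embd.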
########################################################################################################
# RWKV Block
########################################################################################################
class Block(Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln0 = nn.LayerNorm(args.n_embd)  # only used in block 0, should be fused with emb
        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x070(args, layer_id)
        self.ffn = RWKV_CMix_x070(args, layer_id)
    def forward(self, x, v_first):
        if self.layer_id == 0:
            x = self.ln0(x)  # LayerNorm -> [1, seq_len, 768], normalize over the 768 dim, then * gamma + beta

        ln = self.ln1(x)  # LayerNorm -> [1, seq_len, 768]
        xx, v_first = self.att(ln, v_first)  # TimeMix [1, seq_len, 768] -> [1, seq_len, 768]
        x = x + xx  # residual
        x = x + self.ffn(self.ln2(x))  # ChannelMix residual [1, seq_len, 768] -> [1, seq_len, 768]
        return x, v_first
########################################################################################################
# RWKV Model
########################################################################################################
class RWKV(nn.Module):
    def __init__(self, args):
        super().__init__()
        args.dim_att = args.n_embd
        args.dim_ffn = args.n_embd * 4

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def forward(self, idx):
        x = self.emb(idx)  # [1, seq_len] -> [1, seq_len, 768]
        v_first = torch.empty_like(x)  # -> [1, seq_len, 768]
        for block in self.blocks:
            x, v_first = block(x, v_first)
        x = self.ln_out(x)  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = self.head(x)  # [1, seq_len, 768] -> [1, seq_len, 65536]
        return x
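
# Note: v_first carries the value tensor of block 0 through the whole stack; blocks > 0 mix it back
# in through the learned value-residual gate (v0/v1/v2) inside RWKV_Tmix_x070.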
########################################################################################################
# RWKV Inference
########################################################################################################
model_params = torch.load(MODEL_PATH, map_location="cpu")
with torch.no_grad():
    model = RWKV(args).to(dtype=DTYPE).cuda()
    model.load_state_dict(model_params, strict=False)  # we will ignore blocks.0.att.v0/v1/v2
########################################################################################################
prompt = "中国的首都是在"
input = tokenizer.encode(prompt)
print(f"\nInput:\n{input}")
# 中国的首都是在
# 北 [probability 4.99%]
# 中 [probability 4.22%]
# 这 [probability 3.38%]
# 上 [probability 2.74%]
# 东 [probability 2.28%]
# 台 [probability 2.23%]
# 南 [probability 1.86%]
# 广 [probability 1.83%]
# 华 [probability 1.63%]
# 河 [probability 1.47%]
out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
print(f"\nOutput:\n{out}")
# logits of the last token => prediction for the next token
out = out[0, -1]
probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
print(f"\n{prompt}")
_, indices = torch.topk(probs, 10) # print top-10 possibilities
for i in range(len(indices)):
    token_id = indices[i].item()
    token = tokenizer.decode([token_id])
    token_prob = probs[token_id].item()
    print(token, f"[probability {token_prob:.2%}]")
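
# A minimal greedy-decoding sketch (illustrative; 16 new tokens is an arbitrary choice). GPT-mode
# re-runs the full forward pass for every new token, which is why RNN-mode is preferred for
# autoregressive generation.
with torch.no_grad():
    ctx = list(input)
    for _ in range(16):
        logits = model.forward(torch.tensor(ctx).reshape(1, -1).cuda())
        ctx.append(int(torch.argmax(logits[0, -1]).item()))
print(f"\nGreedy continuation:\n{tokenizer.decodeBytes(ctx).decode('utf-8', errors='replace')}")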
########################################################################################################
with open("misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
    todo = [json.loads(line) for line in f]
    todo = [[doc["text"].rsplit(" ", 1)[0], " " + doc["text"].rsplit(" ", 1)[1]] for doc in todo]
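
# Per example, the target (the last word) is scored with teacher forcing: "ppl" is
# exp(-mean over examples of the summed log-prob of the target tokens), and "acc" counts an
# example as correct only if every target token is the argmax prediction.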
print("\nCheck LAMBADA...")
xsum = 0
xcnt = 0
xacc = 0
for d in todo:
    src = [0] + tokenizer.encode(d[0])
    dst = tokenizer.encode(d[1])

    logits = 0
    correct = True
    out = model.forward(torch.tensor(src + dst).reshape(1, -1).cuda())
    for i in range(len(dst)):
        ooo = out[0, len(src) - 1 + i].float()
        probs = F.softmax(ooo, dim=-1)
        logits += math.log(probs[dst[i]])
        if torch.argmax(probs).item() != dst[i]:
            correct = False

    xcnt += 1
    xsum += logits
    xacc += 1 if correct else 0
    if xcnt % 100 == 0 or xcnt == len(todo):
        print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))