2025-03-03 14:53:15 +08:00
|
|
|
|
########################################################################################################
|
|
|
|
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
import torch, types, os, gc, math, json
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch.nn as nn
|
2025-03-03 15:47:21 +08:00
|
|
|
|
from torch.nn import Module
|
2025-03-03 14:53:15 +08:00
|
|
|
|
from torch.nn import functional as F
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
# numpy: compact, fixed-precision array printing for debug output.
np.set_printoptions(linewidth=200, precision=4, suppress=True)

# cuDNN autotuner on; allow TF32 in cuDNN convolutions and CUDA matmuls
# (PyTorch documents TF32 as trading some float32 precision for speed).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
# Optional reduced-precision reductions, intentionally left disabled:
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True

# Private torch API; presumably disables autocast inside TorchScript JIT code - TODO confirm.
torch._C._jit_set_autocast_mode(False)
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
"""
|
2025-03-03 14:53:15 +08:00
|
|
|
|
This loads RWKV-7 "Goose" x070 and runs inference in GPT mode (slower than RNN mode for autoregressive generation)
|
2025-03-03 15:47:21 +08:00
|
|
|
|
"""
|
2025-03-03 14:53:15 +08:00
|
|
|
|
|
|
|
|
|
args = types.SimpleNamespace()

# model download: https://huggingface.co/BlinkDL/rwkv-7-world
# Set the RWKV_MODEL_PATH environment variable to use a local checkpoint; the
# fallback below is a machine-specific absolute path.
MODEL_PATH = os.environ.get(
    "RWKV_MODEL_PATH",
    "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
)

# for 0.1B
args.n_layer = 12
args.n_embd = 768
# Inner widths of the low-rank projections used by RWKV_Tmix_x070
# (decay w1/w2, a1/a2, value-residual v1/v2, gate g1/g2).
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
D_GATE_LORA = 128

args.vocab_size = 65536

# DTYPE = torch.bfloat16
DTYPE = torch.half  # better

args.head_size_a = 64  # don't change
HS = args.head_size_a
|
2025-03-03 14:53:15 +08:00
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
# RWKV Tokenizer (slow version)
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
|
|
|
|
class RWKV_TOKENIZER:
    """Greedy byte-level tokenizer for the RWKV vocabulary (slow, pure-Python version).

    Each vocabulary line has the form ``<id> <python str or bytes literal> <byte length>``.
    Encoding walks the input left to right. At each position it takes the first
    candidate token (in table order) that prefixes the remaining bytes, and falls
    back to the single byte when none matches.
    """

    table: list[list[list[bytes]]]  # table[b0][b1]: tokens of length >= 2 starting with bytes b0, b1
    good: list[set[int]]  # good[b0]: second bytes b1 for which table[b0][b1] is non-empty
    wlen: list[int]  # wlen[b0]: longest token length starting with byte b0

    def __init__(self, file_name):
        """Load the vocabulary from *file_name* (UTF-8 text) and build the match tables.

        Raises:
            AssertionError: if an entry is not str/bytes or its declared byte length is wrong.
        """
        import ast  # literal_eval parses the str/bytes literals without executing arbitrary code

        self.idx2token = {}
        ordered_tokens = []  # tokens in file order; the file must already be sorted
        with open(file_name, "r", encoding="utf-8") as vocab_file:
            lines = vocab_file.readlines()
        for line in lines:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            first_space = line.index(" ")
            last_space = line.rindex(" ")
            idx = int(line[:first_space])
            x = ast.literal_eval(line[first_space:last_space].strip())
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(line[last_space:])
            ordered_tokens.append(x)
            self.idx2token[idx] = x

        self.token2idx = {v: int(k) for k, v in self.idx2token.items()}

        # precompute some tables for fast matching
        self.table = [[[] for _ in range(256)] for _ in range(256)]
        self.good = [set() for _ in range(256)]
        self.wlen = [0 for _ in range(256)]

        for s in reversed(ordered_tokens):  # reverse order - match longer tokens first
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1].append(s)
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Greedily tokenize *src*; raises KeyError if a single byte is not in the vocabulary."""
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]  # default: the single byte at position i

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # first multi-byte candidate that prefixes the input; otherwise keep the single byte
                    s = next(filter(sss.startswith, self.table[s0][s1]), s)

            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the raw bytes of the given token ids."""
        return b"".join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        """Tokenize a str (UTF-8 encoded first)."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids to str; raises UnicodeDecodeError if the bytes are not valid UTF-8."""
        return self.decodeBytes(tokens).decode("utf-8")

    def printTokens(self, tokens):
        """Print each token (decoded when it is valid UTF-8, raw bytes otherwise) followed by its id."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode("utf-8")
            except UnicodeDecodeError:
                pass  # partial UTF-8 sequence: show the raw bytes instead
            print(f"{repr(s)}{i}", end=" ")
            # print(repr(s), i)
        print()
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
# Vocabulary file path is resolved relative to the current working directory.
tokenizer = RWKV_TOKENIZER(file_name="rwkv_vocab_v20230424.txt")
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
########################################################################################################
|
|
|
|
|
# RWKV TimeMix
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
|
|
|
|
class RWKV_Tmix_x070(Module):
    """RWKV-7 (x070) time-mixing layer evaluated over a whole sequence ("GPT mode").

    The layer runs in five stages:

    1. Project the token-shifted input to receptance ``r``, key ``k``, value ``v``,
       decay ``w``, ``a`` and gate ``g``.
    2. Run the per-head recurrent state update over the sequence (``_rwkv7_op``).
    3. Apply group norm.
    4. Add a per-head ``(r * k * r_k)``-weighted ``v`` bonus.
    5. Gate the result and apply the output projection.

    Args:
        args: namespace providing ``head_size_a``, ``dim_att`` and ``n_embd``.
        layer_id: layer index; layer 0 produces ``v_first`` and later layers
            blend their own ``v`` toward it.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        # every head must receive exactly head_size channels
        assert args.dim_att % self.head_size == 0

        H = self.n_head
        N = self.head_size
        C = args.n_embd

        # per-channel token-shift interpolation weights, one per branch
        self.x_r = nn.Parameter(torch.empty(1, 1, C))
        self.x_w = nn.Parameter(torch.empty(1, 1, C))
        self.x_k = nn.Parameter(torch.empty(1, 1, C))
        self.x_v = nn.Parameter(torch.empty(1, 1, C))
        self.x_a = nn.Parameter(torch.empty(1, 1, C))
        self.x_g = nn.Parameter(torch.empty(1, 1, C))

        # bias + low-rank factors (C -> D -> C) for decay w, a, value-residual gate, output gate
        self.w0 = nn.Parameter(torch.empty(1, 1, C))
        self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
        self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))

        self.a0 = nn.Parameter(torch.empty(1, 1, C))
        self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
        self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))

        self.v0 = nn.Parameter(torch.empty(1, 1, C))
        self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
        self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))

        self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
        self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))

        self.k_k = nn.Parameter(torch.empty(1, 1, C))
        self.k_a = nn.Parameter(torch.empty(1, 1, C))
        self.r_k = nn.Parameter(torch.empty(H, N))

        self.receptance = nn.Linear(C, C, bias=False)
        self.key = nn.Linear(C, C, bias=False)
        self.value = nn.Linear(C, C, bias=False)
        self.output = nn.Linear(C, C, bias=False)
        self.ln_x = nn.GroupNorm(H, C, eps=64e-5)  # !!! notice eps value !!!

    @staticmethod
    def _rwkv7_op(r, w, k, v, a, b, head_size):
        """Sequential RWKV-7 state recurrence in float32.

        All tensor inputs are (B, T, C) with C = H * head_size. Per head, with a
        (head_size x head_size) state S starting at zero, for each t:
            S = S * exp(-exp(w_t)) + S @ a_t @ b_t + v_t @ k_t^T;   out_t = S @ r_t
        where the decay multiplies along the key (last) axis of S.

        Returns:
            float32 tensor of shape (B, T, C).
        """
        B, T, C = r.size()
        N = head_size
        H = C // N

        r = r.view(B, T, H, N, 1).float()  # column vectors
        k = k.view(B, T, H, 1, N).float()  # row vectors
        v = v.view(B, T, H, N, 1).float()
        a = a.view(B, T, H, N, 1).float()
        b = b.view(B, T, H, 1, N).float()
        decay = torch.exp(-torch.exp(w.view(B, T, H, 1, N).float()))

        out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
        state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)  # one N x N state per head

        for t in range(T):
            state = state * decay[:, t] + state @ a[:, t] @ b[:, t] + v[:, t] @ k[:, t]
            out[:, t] = (state @ r[:, t]).view(B, H, N)

        return out.view(B, T, C)

    def forward(self, x, v_first):
        """Apply time mixing.

        Args:
            x: input of shape (B, T, C).
            v_first: value tensor produced by layer 0 (ignored and replaced when
                this is layer 0).

        Returns:
            (output of shape (B, T, C), v_first)
        """
        B, T, C = x.size()
        H = self.n_head

        # token shift: xx[:, t] = x[:, t-1] - x[:, t], treating x[:, -1] as zeros
        xx = torch.zeros_like(x)
        xx[:, 0, :] = -x[:, 0, :]
        xx[:, 1:, :] = x[:, :-1, :] - x[:, 1:, :]

        # per-channel interpolation between x[t] and x[t-1] for each branch
        xr = x + xx * self.x_r
        xw = x + xx * self.x_w
        xk = x + xx * self.x_k
        xv = x + xx * self.x_v
        xa = x + xx * self.x_a
        xg = x + xx * self.x_g

        r = self.receptance(xr)
        # -softplus(-z) lies in (-inf, 0), so w lies in (-inf, -0.5)
        w = -F.softplus(-(torch.tanh(xw @ self.w1) @ self.w2 + self.w0)) - 0.5
        k = self.key(xk)
        v = self.value(xv)

        if self.layer_id == 0:
            v_first = v  # store the v of the first layer
        else:
            # value residual: sigmoid-gated blend toward layer 0's v
            v = v + (v_first - v) * torch.sigmoid((xv @ self.v1) @ self.v2 + self.v0)

        a = torch.sigmoid((xa @ self.a1) @ self.a2 + self.a0)
        g = torch.sigmoid(xg @ self.g1) @ self.g2

        # kk: per-head L2-normalised (k * k_k)
        kk = k * self.k_k
        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
        k = k * (1 + (a - 1) * self.k_a)

        # recurrent core runs in float32; cast back to the activation dtype
        x = self._rwkv7_op(r, w, k, v, -kk, kk * a, self.head_size).to(dtype=r.dtype)
        x = self.ln_x(x.view(B * T, C)).view(B, T, C)

        # per-head bonus: sum over head channels of (r * k * r_k), times v
        bonus = (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True)
        x = x + (bonus * v.view(B, T, H, -1)).view(B, T, C)

        x = self.output(x * g)
        return x, v_first
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
########################################################################################################
|
|
|
|
|
# RWKV ChannelMix
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
|
|
|
|
class RWKV_CMix_x070(Module):
    """RWKV-7 channel-mixing (feed-forward) layer.

    Blends each position with its predecessor through a learned per-channel
    weight, then computes ``value(relu(key(.)) ** 2)``.

    Args:
        args: namespace providing ``n_embd`` and ``dim_ffn``.
        layer_id: layer index (stored only; not used in the computation).
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        embd, hidden = args.n_embd, args.dim_ffn
        self.x_k = nn.Parameter(torch.empty(1, 1, embd))
        self.key = nn.Linear(embd, hidden, bias=False)  # n_embd -> dim_ffn
        self.value = nn.Linear(hidden, embd, bias=False)  # dim_ffn -> n_embd

    def forward(self, x):
        """Map (B, T, n_embd) -> (B, T, n_embd); position 0 sees an all-zero predecessor."""
        previous = torch.cat((torch.zeros_like(x[:, :1, :]), x[:, :-1, :]), dim=1)
        mixed = x + (previous - x) * self.x_k
        activated = torch.relu(self.key(mixed)) ** 2
        return self.value(activated)
|
2025-03-03 14:53:15 +08:00
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
########################################################################################################
|
|
|
|
|
# RWKV Block
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
|
|
|
|
class Block(Module):
    """One RWKV-7 layer: pre-LayerNorm time mixing, then pre-LayerNorm channel mixing, each residual.

    Args:
        args: model namespace passed through to the sub-layers.
        layer_id: layer index; block 0 additionally normalises its input with ``ln0``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        width = args.n_embd
        self.ln0 = nn.LayerNorm(width)  # only used in block 0, should be fused with emb
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

        self.att = RWKV_Tmix_x070(args, layer_id)
        self.ffn = RWKV_CMix_x070(args, layer_id)

    def forward(self, x, v_first):
        """Return (updated hidden states of the same shape as x, v_first from the attention layer)."""
        if self.layer_id == 0:
            x = self.ln0(x)

        attn_out, v_first = self.att(self.ln1(x), v_first)
        x = x + attn_out

        ffn_out = self.ffn(self.ln2(x))
        return x + ffn_out, v_first
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
########################################################################################################
|
|
|
|
|
# RWKV Model
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
class RWKV(nn.Module):
    """Full RWKV-7 language model: embedding, a stack of Blocks, final LayerNorm and LM head.

    Note: the constructor writes ``dim_att`` and ``dim_ffn`` onto the passed ``args``.
    """

    def __init__(self, args):
        super().__init__()
        # derived widths consumed by the sub-layers
        args.dim_att = args.n_embd
        args.dim_ffn = args.n_embd * 4

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
        self.blocks = nn.ModuleList(Block(args, layer) for layer in range(args.n_layer))
        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        hidden = self.emb(idx)
        # placeholder; block 0 replaces it with its own value projection
        v_first = torch.empty_like(hidden)
        for block in self.blocks:
            hidden, v_first = block(hidden, v_first)
        return self.head(self.ln_out(hidden))
|
|
|
|
|
|
2025-03-03 15:47:21 +08:00
|
|
|
|
|
2025-03-03 14:53:15 +08:00
|
|
|
|
########################################################################################################
|
|
|
|
|
# RWKV Inference
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
model_params = torch.load(MODEL_PATH, map_location="cpu")

with torch.no_grad():
    model = RWKV(args).to(dtype=DTYPE).cuda()
    # strict=False tolerates checkpoint/module key differences (the original note:
    # "we will ignore blocks.0.att.v0/v1/v2"). Several parameters are created with
    # torch.empty, so a missing key leaves uninitialised values in the model;
    # report any mismatch instead of discarding load_state_dict's result.
    load_report = model.load_state_dict(model_params, strict=False)
    if load_report.missing_keys:
        print(f"[load_state_dict] missing keys (left uninitialised): {load_report.missing_keys}")
    if load_report.unexpected_keys:
        print(f"[load_state_dict] unexpected keys (ignored): {load_report.unexpected_keys}")
|
2025-03-03 14:53:15 +08:00
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
prompt = "中国的首都是在"
prompt_tokens = tokenizer.encode(prompt)
print(f"\nInput:\n{prompt_tokens}")

# Sample top-10 output previously observed for this prompt:
# 中国的首都是在
# 北 [probability 4.99%]
# 中 [probability 4.22%]
# 这 [probability 3.38%]
# 上 [probability 2.74%]
# 东 [probability 2.28%]
# 台 [probability 2.23%]
# 南 [probability 1.86%]
# 广 [probability 1.83%]
# 华 [probability 1.63%]
# 河 [probability 1.47%]

with torch.no_grad():  # inference only: do not build an autograd graph
    out = model.forward(torch.tensor(prompt_tokens).reshape(1, -1).cuda())
print(f"\nOutput:\n{out}")

# logits of the last token => prediction for the next token
out = out[0, -1]

probs = F.softmax(out.float(), dim=-1)  # compute softmax in float (more accurate)

print(f"\n{prompt}")

_, indices = torch.topk(probs, 10)  # print top-10 possibilities
for token_id in indices.tolist():
    # The vocabulary is byte-level, so a single token may be an incomplete UTF-8
    # sequence (e.g. part of a CJK character); tokenizer.decode would raise
    # UnicodeDecodeError, so decode leniently for display.
    token = tokenizer.decodeBytes([token_id]).decode("utf-8", errors="replace")
    token_prob = probs[token_id].item()
    print(token, f"[probability {token_prob:.2%}]")
|
2025-03-03 14:53:15 +08:00
|
|
|
|
|
|
|
|
|
########################################################################################################
|
|
|
|
|
|
|
|
|
|
with open("misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
    docs = [json.loads(line) for line in f]

# Split each passage into [context, " " + final word]; the model must predict the final word.
todo = []
for doc in docs:
    context, _, last_word = doc["text"].rpartition(" ")
    todo.append([context, " " + last_word])

print("\nCheck LAMBADA...")
xsum = 0  # summed log-likelihood of the target words
xcnt = 0  # passages evaluated
xacc = 0  # passages whose every target token is the argmax prediction

with torch.no_grad():  # inference only: do not build an autograd graph
    for d in todo:
        src = [0] + tokenizer.encode(d[0])
        dst = tokenizer.encode(d[1])

        out = model.forward(torch.tensor(src + dst).reshape(1, -1).cuda())
        # Logits at position p predict token p + 1, so dst[j] is scored at len(src) - 1 + j.
        # log_softmax avoids log(0) when a probability underflows in float32.
        start = len(src) - 1
        log_probs = F.log_softmax(out[0, start : start + len(dst)].float(), dim=-1)

        log_likelihood = 0.0
        correct = True
        for j, token_id in enumerate(dst):
            log_likelihood += log_probs[j, token_id].item()
            if torch.argmax(log_probs[j]).item() != token_id:
                correct = False

        xcnt += 1
        xsum += log_likelihood
        xacc += int(correct)
        if xcnt % 100 == 0 or xcnt == len(todo):
            print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))
|