Compare commits

...

2 Commits

Author SHA1 Message Date
Colin 240858c030 Update rwkv. 2025-03-03 21:30:58 +08:00
Colin 4f18296e40 Format rwkv/RWKV-v7/rwkv_v7_demo.py 2025-03-03 15:47:21 +08:00
6 changed files with 116 additions and 289 deletions

View File

@@ -1,55 +0,0 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
F *__restrict__ const _y)
{
const int e = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
float state[_N_] = {0};
__shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
__syncthreads();
r[i] = float(_r[t]);
w[i] = __expf(-__expf(float(_w[t])));
k[i] = float(_k[t]);
a[i] = float(_a[t]);
b[i] = float(_b[t]);
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
sa += a[j] * state[j];
}
float vv = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
s = s * w[j] + k[j] * vv + sa * b[j];
y += s * r[j];
}
_y[t] = F(y);
}
}
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
assert(H*_N_ == C);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, w, k, v, a, b, y);
}

View File

@@ -1,15 +0,0 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
TORCH_LIBRARY(wkv7, m) {
m.def("forward", forward);
}

View File

@@ -1,64 +0,0 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
F *__restrict__ const _y)
{
const int e = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
float state[_N_];
#pragma unroll
for (int j = 0; j < _N_; j++)
state[j] = _state[j];
__shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
__syncthreads();
r[i] = float(_r[t]);
w[i] = __expf(-__expf(float(_w[t])));
k[i] = float(_k[t]);
a[i] = float(_a[t]);
b[i] = float(_b[t]);
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
sa += a[j] * state[j];
}
float vv = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
s = s * w[j] + k[j] * vv + sa * b[j];
y += s * r[j];
}
_y[t] = F(y);
}
#pragma unroll
for (int j = 0; j < _N_; j++)
_state[j] = state[j];
}
void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
assert(H*_N_ == C);
assert(B == 1); // only for B=1
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}

View File

@@ -1,15 +0,0 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
cuda_forward(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
TORCH_LIBRARY(wkv7s, m) {
m.def("forward", forward);
}

18 rwkv/RWKV-v7/model.md Normal file
View File

@@ -0,0 +1,18 @@
R-Receptance: this receptance can be read directly from the code; it is the degree to which the model draws on its memory of the past.
W-Weight: this Weight is not a generic term; it is the time decay applied to past information.
K, V are the equivalents of the Transformer's Key and Value.
- Remember past information (via V)
- Find relevant information (via K)
- Control the importance of information (via W)
- Decide how much information to use (via R)
TimeMix refers to mixing the previous token's information x-1 with the current token's information x; `xx = self.time_shift(x) - x` is the typical operation (see the sketch below).
nn.Embedding
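
The TimeMix token-shift mentioned above can be illustrated with a minimal, self-contained sketch (the shapes and the zero-initialized mixing coefficient are assumptions for illustration, not part of this commit):

import torch
import torch.nn as nn

# ZeroPad2d((0, 0, 1, -1)) shifts the time dimension right by one step, so each
# position sees the previous token; the first position sees zeros.
time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.randn(1, 4, 8)        # (batch, time, channels)
xx = time_shift(x) - x          # difference between the previous token and the current one
x_r = torch.zeros(1, 1, 8)      # learned per-channel mixing coefficient (zeros only in this sketch)
xr = x + xx * x_r               # receptance input: a per-channel blend of current and previous token

In the demo's RWKV_Tmix_x070, the same pattern produces xr, xw, xk, xv, xa, and xg from their respective x_* parameters.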

View File

@@ -5,7 +5,9 @@
import torch, types, os, gc, math, json import torch, types, os, gc, math, json
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from torch.nn import Module
from torch.nn import functional as F from torch.nn import functional as F
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
@@ -14,9 +16,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True # torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False) torch._C._jit_set_autocast_mode(False)
''' """
This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation) This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
''' """
args = types.SimpleNamespace() args = types.SimpleNamespace()
@@ -35,35 +37,31 @@ D_GATE_LORA = 128
args.vocab_size = 65536 args.vocab_size = 65536
# DTYPE = torch.bfloat16 # DTYPE = torch.bfloat16
DTYPE = torch.half # better DTYPE = torch.half # better
args.head_size_a = 64 # don't change args.head_size_a = 64 # don't change
HEAD_SIZE = args.head_size_a HEAD_SIZE = args.head_size_a
USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script
######################################################################################################## ########################################################################################################
# RWKV Tokenizer (slow version) # RWKV Tokenizer (slow version)
######################################################################################################## ########################################################################################################
class RWKV_TOKENIZER():
class RWKV_TOKENIZER:
table: list[list[list[bytes]]] table: list[list[list[bytes]]]
good: list[set[int]] good: list[set[int]]
wlen: list[int] wlen: list[int]
def __init__(self, file_name): def __init__(self, file_name):
self.idx2token = {} self.idx2token = {}
sorted = [] # must be already sorted sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines() lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines: for l in lines:
idx = int(l[:l.index(' ')]) idx = int(l[: l.index(" ")])
x = eval(l[l.index(' '):l.rindex(' ')]) x = eval(l[l.index(" ") : l.rindex(" ")])
x = x.encode("utf-8") if isinstance(x, str) else x x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes) assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):]) assert len(x) == int(l[l.rindex(" ") :])
sorted += [x] sorted += [x]
self.idx2token[idx] = x self.idx2token[idx] = x
@@ -76,7 +74,7 @@ class RWKV_TOKENIZER():
self.good = [set() for i in range(256)] self.good = [set() for i in range(256)]
self.wlen = [0 for i in range(256)] self.wlen = [0 for i in range(256)]
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
s = sorted[i] s = sorted[i]
if len(s) >= 2: if len(s) >= 2:
s0 = int(s[0]) s0 = int(s[0])
@@ -107,106 +105,35 @@ class RWKV_TOKENIZER():
return tokens return tokens
def decodeBytes(self, tokens): def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens)) return b"".join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src: str): def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8")) return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens): def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8') return self.decodeBytes(tokens).decode("utf-8")
def printTokens(self, tokens): def printTokens(self, tokens):
for i in tokens: for i in tokens:
s = self.idx2token[i] s = self.idx2token[i]
try: try:
s = s.decode('utf-8') s = s.decode("utf-8")
except: except:
pass pass
print(f'{repr(s)}{i}', end=' ') print(f"{repr(s)}{i}", end=" ")
# print(repr(s), i) # print(repr(s), i)
print() print()
tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt") tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
########################################################################################################
# CUDA Kernel
########################################################################################################
if USE_CUDA_KERNEL:
from torch.utils.cpp_extension import load
load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
class WKV_7(torch.autograd.Function):
@staticmethod
def forward(ctx, r, w, k, v, a, b):
with torch.no_grad():
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
assert HEAD_SIZE == C // H
assert r.dtype == DTYPE
assert w.dtype == DTYPE
assert k.dtype == DTYPE
assert v.dtype == DTYPE
assert a.dtype == DTYPE
assert b.dtype == DTYPE
assert r.is_contiguous()
assert w.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert a.is_contiguous()
assert b.is_contiguous()
y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format)
torch.ops.wkv7.forward(B, T, C, H, r, w, k, v, a, b, y)
return y
def RWKV7_OP(r, w, k, v, a, b):
return WKV_7.apply(r, w, k, v, a, b)
else:
def RWKV7_OP(r, w, k, v, a, b):
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
r = r.view(B, T, H, N).float()
k = k.view(B, T, H, N).float()
v = v.view(B, T, H, N).float()
a = a.view(B, T, H, N).float()
b = b.view(B, T, H, N).float()
w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
for t in range(T):
kk = k[:, t, :].view(B, H, 1, N)
rr = r[:, t, :].view(B, H, N, 1)
vv = v[:, t, :].view(B, H, N, 1)
aa = a[:, t, :].view(B, H, N, 1)
bb = b[:, t, :].view(B, H, 1, N)
state = state * w[: , t, :, None, :] + state @ aa @ bb + vv @ kk
out[:, t, :] = (state @ rr).view(B, H, N)
# another method using einsum
#
# kk = k[:, t, :]
# rr = r[:, t, :]
# vv = v[:, t, :]
# aa = a[:, t, :]
# bb = b[:, t, :]
# sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb)
# state = state * w[: , t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv)
# out[:, t, :] = torch.einsum('bhj,bhij->bhi', rr, state)
return out.view(B, T, C).to(dtype=DTYPE)
######################################################################################################## ########################################################################################################
# RWKV TimeMix # RWKV TimeMix
######################################################################################################## ########################################################################################################
class RWKV_Tmix_x070(MyModule):
class RWKV_Tmix_x070(Module):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
@@ -220,40 +147,39 @@ class RWKV_Tmix_x070(MyModule):
N = self.head_size N = self.head_size
C = args.n_embd C = args.n_embd
self.x_r = nn.Parameter(torch.empty(1,1,C)) self.x_r = nn.Parameter(torch.empty(1, 1, C))
self.x_w = nn.Parameter(torch.empty(1,1,C)) self.x_w = nn.Parameter(torch.empty(1, 1, C))
self.x_k = nn.Parameter(torch.empty(1,1,C)) self.x_k = nn.Parameter(torch.empty(1, 1, C))
self.x_v = nn.Parameter(torch.empty(1,1,C)) self.x_v = nn.Parameter(torch.empty(1, 1, C))
self.x_a = nn.Parameter(torch.empty(1,1,C)) self.x_a = nn.Parameter(torch.empty(1, 1, C))
self.x_g = nn.Parameter(torch.empty(1,1,C)) self.x_g = nn.Parameter(torch.empty(1, 1, C))
self.w0 = nn.Parameter(torch.empty(1,1,C)) self.w0 = nn.Parameter(torch.empty(1, 1, C))
self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA)) self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C)) self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))
self.a0 = nn.Parameter(torch.empty(1,1,C)) self.a0 = nn.Parameter(torch.empty(1, 1, C))
self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA)) self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C)) self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))
self.v0 = nn.Parameter(torch.empty(1,1,C)) self.v0 = nn.Parameter(torch.empty(1, 1, C))
self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA)) self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C)) self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))
self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA)) self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C)) self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))
self.k_k = nn.Parameter(torch.empty(1,1,C)) self.k_k = nn.Parameter(torch.empty(1, 1, C))
self.k_a = nn.Parameter(torch.empty(1,1,C)) self.k_a = nn.Parameter(torch.empty(1, 1, C))
self.r_k = nn.Parameter(torch.empty(H,N)) self.r_k = nn.Parameter(torch.empty(H, N))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(C, C, bias=False) self.receptance = nn.Linear(C, C, bias=False)
self.key = nn.Linear(C, C, bias=False) self.key = nn.Linear(C, C, bias=False)
self.value = nn.Linear(C, C, bias=False) self.value = nn.Linear(C, C, bias=False)
self.output = nn.Linear(C, C, bias=False) self.output = nn.Linear(C, C, bias=False)
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!! self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
@MyFunction
def forward(self, x, v_first): def forward(self, x, v_first):
B, T, C = x.size() B, T, C = x.size()
H = self.n_head H = self.n_head
@@ -267,32 +193,61 @@ class RWKV_Tmix_x070(MyModule):
xg = x + xx * self.x_g xg = x + xx * self.x_g
r = self.receptance(xr) r = self.receptance(xr)
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5) w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
k = self.key(xk) k = self.key(xk)
v = self.value(xv) v = self.value(xv)
if self.layer_id == 0: if self.layer_id == 0:
v_first = v # store the v of the first layer v_first = v # store the v of the first layer
else: else:
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate" a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2 g = torch.sigmoid(xg @ self.g1) @ self.g2
kk = k * self.k_k kk = k * self.k_k
kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C) kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
k = k * (1 + (a-1) * self.k_a) k = k * (1 + (a - 1) * self.k_a)
def RWKV7_OP(r, w, k, v, a, b):
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
r = r.view(B, T, H, N).float()
k = k.view(B, T, H, N).float()
v = v.view(B, T, H, N).float()
a = a.view(B, T, H, N).float()
b = b.view(B, T, H, N).float()
w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
for t in range(T):
kk = k[:, t, :].view(B, H, 1, N)
rr = r[:, t, :].view(B, H, N, 1)
vv = v[:, t, :].view(B, H, N, 1)
aa = a[:, t, :].view(B, H, N, 1)
bb = b[:, t, :].view(B, H, 1, N)
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
out[:, t, :] = (state @ rr).view(B, H, N)
return out.view(B, T, C).to(dtype=DTYPE)
x = RWKV7_OP(r, w, k, v, -kk, kk * a)
x = RWKV7_OP(r, w, k, v, -kk, kk*a)
x = self.ln_x(x.view(B * T, C)).view(B, T, C) x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C) x = x + (
(r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
).view(B, T, C)
x = self.output(x * g) x = self.output(x * g)
return x, v_first return x, v_first
######################################################################################################## ########################################################################################################
# RWKV ChannelMix # RWKV ChannelMix
######################################################################################################## ########################################################################################################
class RWKV_CMix_x070(MyModule):
class RWKV_CMix_x070(Module):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
@@ -305,32 +260,32 @@ class RWKV_CMix_x070(MyModule):
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x): def forward(self, x):
xx = self.time_shift(x) - x xx = self.time_shift(x) - x
k = x + xx * self.x_k k = x + xx * self.x_k
k = torch.relu(self.key(k)) ** 2 k = torch.relu(self.key(k)) ** 2
return self.value(k) return self.value(k)
######################################################################################################## ########################################################################################################
# RWKV Block # RWKV Block
######################################################################################################## ########################################################################################################
class Block(MyModule):
class Block(Module):
def __init__(self, args, layer_id): def __init__(self, args, layer_id):
super().__init__() super().__init__()
self.args = args self.args = args
self.layer_id = layer_id self.layer_id = layer_id
self.ln0 = nn.LayerNorm(args.n_embd) # only used in block 0, should be fused with emb self.ln0 = nn.LayerNorm(args.n_embd) # only used in block 0, should be fused with emb
self.ln1 = nn.LayerNorm(args.n_embd) self.ln1 = nn.LayerNorm(args.n_embd)
self.ln2 = nn.LayerNorm(args.n_embd) self.ln2 = nn.LayerNorm(args.n_embd)
self.att = RWKV_Tmix_x070(args, layer_id) self.att = RWKV_Tmix_x070(args, layer_id)
self.ffn = RWKV_CMix_x070(args, layer_id) self.ffn = RWKV_CMix_x070(args, layer_id)
@MyFunction
def forward(self, x, v_first): def forward(self, x, v_first):
if self.layer_id == 0: if self.layer_id == 0:
@@ -342,10 +297,12 @@ class Block(MyModule):
return x, v_first return x, v_first
######################################################################################################## ########################################################################################################
# RWKV Model # RWKV Model
######################################################################################################## ########################################################################################################
class RWKV(nn.Module): class RWKV(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
@@ -371,6 +328,7 @@ class RWKV(nn.Module):
return x return x
######################################################################################################## ########################################################################################################
# RWKV Inference # RWKV Inference
######################################################################################################## ########################################################################################################
@@ -380,38 +338,38 @@ model_params = torch.load(MODEL_PATH, map_location="cpu")
with torch.no_grad(): with torch.no_grad():
model = RWKV(args).to(dtype=DTYPE).cuda() model = RWKV(args).to(dtype=DTYPE).cuda()
model.load_state_dict(model_params, strict=False) # we will ignore blocks.0.att.v0/v1/v2 model.load_state_dict(model_params, strict=False) # we will ignore blocks.0.att.v0/v1/v2
######################################################################################################## ########################################################################################################
prompt = "中国的首都是在" prompt = "中国的首都是在"
input = tokenizer.encode(prompt) input = tokenizer.encode(prompt)
print(f'\nInput:\n{input}') print(f"\nInput:\n{input}")
out = model.forward(torch.tensor(input).reshape(1,-1).cuda()) out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
print(f'\nOutput:\n{out}') print(f"\nOutput:\n{out}")
# logits of the last token => prediction for the next token # logits of the last token => prediction for the next token
out = out[0, -1] out = out[0, -1]
probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
print(f'\n{prompt}') probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
_, indices = torch.topk(probs, 10) # print top-10 possibilities print(f"\n{prompt}")
_, indices = torch.topk(probs, 10) # print top-10 possibilities
for i in range(len(indices)): for i in range(len(indices)):
token_id = indices[i].item() token_id = indices[i].item()
token = tokenizer.decode([token_id]) token = tokenizer.decode([token_id])
token_prob = probs[token_id].item() token_prob = probs[token_id].item()
print(token, f'[probability {token_prob:.2%}]') print(token, f"[probability {token_prob:.2%}]")
######################################################################################################## ########################################################################################################
with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f: with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
todo = [json.loads(line) for line in f] todo = [json.loads(line) for line in f]
todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo] todo = [[doc["text"].rsplit(" ", 1)[0], " " + doc["text"].rsplit(" ", 1)[1]] for doc in todo]
print('\nCheck LAMBADA...') print("\nCheck LAMBADA...")
xsum = 0 xsum = 0
xcnt = 0 xcnt = 0
xacc = 0 xacc = 0
@@ -421,9 +379,9 @@ with torch.no_grad():
logits = 0 logits = 0
correct = True correct = True
out = model.forward(torch.tensor(src+dst).reshape(1,-1).cuda()) out = model.forward(torch.tensor(src + dst).reshape(1, -1).cuda())
for i in range(len(dst)): for i in range(len(dst)):
ooo = out[0,len(src)-1+i].float() ooo = out[0, len(src) - 1 + i].float()
probs = F.softmax(ooo, dim=-1) probs = F.softmax(ooo, dim=-1)
logits += math.log(probs[dst[i]]) logits += math.log(probs[dst[i]])
if torch.argmax(probs).item() != dst[i]: if torch.argmax(probs).item() != dst[i]:
@@ -433,4 +391,4 @@ with torch.no_grad():
xsum += logits xsum += logits
xacc += 1 if correct else 0 xacc += 1 if correct else 0
if xcnt % 100 == 0 or xcnt == len(todo): if xcnt % 100 == 0 or xcnt == len(todo):
print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2)) print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))