diff --git a/rwkv/RWKV-v7/cuda/wkv7.cu b/rwkv/RWKV-v7/cuda/wkv7.cu
deleted file mode 100644
index 5a390f9..0000000
--- a/rwkv/RWKV-v7/cuda/wkv7.cu
+++ /dev/null
@@ -1,55 +0,0 @@
-#include <stdio.h>
-#include <assert.h>
-#include "ATen/ATen.h"
-
-typedef at::Half bf16;
-// typedef at::BFloat16 bf16;
-
-template <typename F>
-__global__ void kernel_forward(const int B, const int T, const int C, const int H,
-    const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
-    F *__restrict__ const _y)
-{
-    const int e = blockIdx.x / H;
-    const int h = blockIdx.x % H;
-    const int i = threadIdx.x;
-
-    float state[_N_] = {0};
-    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
-
-    for (int _t = 0; _t < T; _t++)
-    {
-        const int t = e*T*C + h*_N_ + i + _t * C;
-        __syncthreads();
-        r[i] = float(_r[t]);
-        w[i] = __expf(-__expf(float(_w[t])));
-        k[i] = float(_k[t]);
-        a[i] = float(_a[t]);
-        b[i] = float(_b[t]);
-        __syncthreads();
-
-        float sa = 0;
-        #pragma unroll
-        for (int j = 0; j < _N_; j++)
-        {
-            sa += a[j] * state[j];
-        }
-
-        float vv = float(_v[t]);
-        float y = 0;
-        #pragma unroll
-        for (int j = 0; j < _N_; j++)
-        {
-            float& s = state[j];
-            s = s * w[j] + k[j] * vv + sa * b[j];
-            y += s * r[j];
-        }
-        _y[t] = F(y);
-    }
-}
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
-{
-    assert(H*_N_ == C);
-    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, w, k, v, a, b, y);
-}
diff --git a/rwkv/RWKV-v7/cuda/wkv7_op.cpp b/rwkv/RWKV-v7/cuda/wkv7_op.cpp
deleted file mode 100644
index 1885ec1..0000000
--- a/rwkv/RWKV-v7/cuda/wkv7_op.cpp
+++ /dev/null
@@ -1,15 +0,0 @@
-#include <torch/extension.h>
-#include "ATen/ATen.h"
-
-typedef at::Half bf16;
-// typedef at::BFloat16 bf16;
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
-
-void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
-    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
-}
-
-TORCH_LIBRARY(wkv7, m) {
-    m.def("forward", forward);
-}
diff --git a/rwkv/RWKV-v7/cuda/wkv7s.cu b/rwkv/RWKV-v7/cuda/wkv7s.cu
deleted file mode 100644
index 633f3b4..0000000
--- a/rwkv/RWKV-v7/cuda/wkv7s.cu
+++ /dev/null
@@ -1,64 +0,0 @@
-#include <stdio.h>
-#include <assert.h>
-#include "ATen/ATen.h"
-
-typedef at::Half bf16;
-// typedef at::BFloat16 bf16;
-
-template <typename F>
-__global__ void kernel_forward(const int B, const int T, const int C, const int H,
-    float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
-    F *__restrict__ const _y)
-{
-    const int e = blockIdx.x / H;
-    const int h = blockIdx.x % H;
-    const int i = threadIdx.x;
-    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
-
-    float state[_N_];
-    #pragma unroll
-    for (int j = 0; j < _N_; j++)
-        state[j] = _state[j];
-
-    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
-
-    for (int _t = 0; _t < T; _t++)
-    {
-        const int t = e*T*C + h*_N_ + i + _t * C;
-        __syncthreads();
-        r[i] = float(_r[t]);
-        w[i] = __expf(-__expf(float(_w[t])));
-        k[i] = float(_k[t]);
-        a[i] = float(_a[t]);
-        b[i] = float(_b[t]);
-        __syncthreads();
-
-        float sa = 0;
-        #pragma unroll
-        for (int j = 0; j < _N_; j++)
-        {
-            sa += a[j] * state[j];
-        }
-
-        float vv = float(_v[t]);
-        float y = 0;
-        #pragma unroll
-        for (int j = 0; j < _N_; j++)
-        {
-            float& s = state[j];
-            s = s * w[j] + k[j] * vv + sa * b[j];
-            y += s * r[j];
-        }
-        _y[t] = F(y);
-    }
-    #pragma unroll
-    for (int j = 0; j < _N_; j++)
-        _state[j] = state[j];
-}
-
-void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
-{
-    assert(H*_N_ == C);
-    assert(B == 1); // only for B=1
-    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
-}
diff --git a/rwkv/RWKV-v7/cuda/wkv7s_op.cpp b/rwkv/RWKV-v7/cuda/wkv7s_op.cpp
deleted file mode 100644
index 2a84092..0000000
--- a/rwkv/RWKV-v7/cuda/wkv7s_op.cpp
+++ /dev/null
@@ -1,15 +0,0 @@
-#include <torch/extension.h>
-#include "ATen/ATen.h"
-
-typedef at::Half bf16;
-// typedef at::BFloat16 bf16;
-
-void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
-
-void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
-    cuda_forward(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
-}
-
-TORCH_LIBRARY(wkv7s, m) {
-    m.def("forward", forward);
-}
diff --git a/rwkv/RWKV-v7/model.md b/rwkv/RWKV-v7/model.md
new file mode 100644
index 0000000..8e52d89
--- /dev/null
+++ b/rwkv/RWKV-v7/model.md
@@ -0,0 +1,18 @@
+
+
+
+R-Receptance: this receptance can be seen directly in the code; it is the model's degree of memory of the past.
+W-Weight: this weight is not a generic term; it is the time decay applied to past information.
+K, V: these are equivalent to the Key and Value of a Transformer.
+
+- Remember past information (via V)
+- Find the relevant information (via K)
+- Control how important the information is (via W)
+- Decide how much of the information to use (via R)
+
+
+TimeMix refers to mixing the previous step's information x[t-1] with the current information x[t]. xx = self.time_shift(x) - x is the typical operation.
+
+
+
+nn.Embedding
\ No newline at end of file
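A quick illustration, separate from the patch itself: the TimeMix note above points at xx = self.time_shift(x) - x. Below is a minimal sketch of that token-shift delta, assuming the nn.ZeroPad2d((0, 0, 1, -1)) definition of time_shift used in the RWKV demo code (tensor sizes here are arbitrary):

    import torch
    import torch.nn as nn

    # Pad one zero row at the front of the time axis and drop the last row,
    # so position t sees the embedding of position t-1 (position 0 sees zeros).
    time_shift = nn.ZeroPad2d((0, 0, 1, -1))

    x = torch.randn(1, 5, 8)    # (batch, time, channels)
    xx = time_shift(x) - x      # xx[:, t] = x[:, t-1] - x[:, t], the "TimeMix" delta
    print(xx.shape)             # torch.Size([1, 5, 8])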
diff --git a/rwkv/RWKV-v7/rwkv_v7_demo.py b/rwkv/RWKV-v7/rwkv_v7_demo.py
index 5ab5b25..f8f2c53 100644
--- a/rwkv/RWKV-v7/rwkv_v7_demo.py
+++ b/rwkv/RWKV-v7/rwkv_v7_demo.py
@@ -42,8 +42,6 @@ DTYPE = torch.half # better
 args.head_size_a = 64 # don't change
 HEAD_SIZE = args.head_size_a
 
-USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
-
 ########################################################################################################
 # RWKV Tokenizer (slow version)
 ########################################################################################################
@@ -129,93 +127,6 @@ class RWKV_TOKENIZER:
 
 tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
 
-########################################################################################################
-# CUDA Kernel
-########################################################################################################
-
-if USE_CUDA_KERNEL:
-
-    from torch.utils.cpp_extension import load
-
-    load(
-        name="wkv7",
-        sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"],
-        is_python_module=False,
-        verbose=True,
-        extra_cuda_cflags=[
-            "-res-usage",
-            "--use_fast_math",
-            "-O3",
-            "-Xptxas -O3",
-            "--extra-device-vectorization",
-            f"-D_N_={HEAD_SIZE}",
-        ],
-    )
-
-    class WKV_7(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, r, w, k, v, a, b):
-            with torch.no_grad():
-                B, T, C = r.size()
-                H = C // HEAD_SIZE
-                N = HEAD_SIZE
-                assert HEAD_SIZE == C // H
-                assert r.dtype == DTYPE
-                assert w.dtype == DTYPE
-                assert k.dtype == DTYPE
-                assert v.dtype == DTYPE
-                assert a.dtype == DTYPE
-                assert b.dtype == DTYPE
-                assert r.is_contiguous()
-                assert w.is_contiguous()
-                assert k.is_contiguous()
-                assert v.is_contiguous()
-                assert a.is_contiguous()
-                assert b.is_contiguous()
-                y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format)
-                torch.ops.wkv7.forward(B, T, C, H, r, w, k, v, a, b, y)
-                return y
-
-    def RWKV7_OP(r, w, k, v, a, b):
-        return WKV_7.apply(r, w, k, v, a, b)
-
-else:
-
-    def RWKV7_OP(r, w, k, v, a, b):
-        B, T, C = r.size()
-        H = C // HEAD_SIZE
-        N = HEAD_SIZE
-        r = r.view(B, T, H, N).float()
-        k = k.view(B, T, H, N).float()
-        v = v.view(B, T, H, N).float()
-        a = a.view(B, T, H, N).float()
-        b = b.view(B, T, H, N).float()
-        w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
-        out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
-        state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
-
-        for t in range(T):
-            kk = k[:, t, :].view(B, H, 1, N)
-            rr = r[:, t, :].view(B, H, N, 1)
-            vv = v[:, t, :].view(B, H, N, 1)
-            aa = a[:, t, :].view(B, H, N, 1)
-            bb = b[:, t, :].view(B, H, 1, N)
-            state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
-            out[:, t, :] = (state @ rr).view(B, H, N)
-
-        # another method using einsum
-        #
-        # kk = k[:, t, :]
-        # rr = r[:, t, :]
-        # vv = v[:, t, :]
-        # aa = a[:, t, :]
-        # bb = b[:, t, :]
-        # sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb)
-        # state = state * w[: , t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv)
-        # out[:, t, :] = torch.einsum('bhj,bhij->bhi', rr, state)
-
-        return out.view(B, T, C).to(dtype=DTYPE)
-
 ########################################################################################################
 # RWKV TimeMix
 ########################################################################################################
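An aside, not part of the patch: reading the reference loop deleted above (and the CUDA kernels), the per-head recurrence it computes can be written as

    S_t = S_{t-1} \,\mathrm{diag}(w_t) + (S_{t-1} a_t)\, b_t^{\top} + v_t k_t^{\top}, \qquad y_t = S_t r_t

where S_t is the N-by-N state, w_t = \exp(-\exp(\hat{w}_t)) is the decay, and r_t, k_t, v_t, a_t, b_t are the per-head vectors at step t. In the demo's call RWKV7_OP(r, w, k, v, -kk, kk * a), the a/b slots carry -kk_t and kk_t * a_t, so the middle term roughly erases and rewrites the state along the normalized key direction before the new outer product v_t k_t^{\top} is added.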
"--extra-device-vectorization", - f"-D_N_={HEAD_SIZE}", - ], - ) - - class WKV_7(torch.autograd.Function): - @staticmethod - def forward(ctx, r, w, k, v, a, b): - with torch.no_grad(): - B, T, C = r.size() - H = C // HEAD_SIZE - N = HEAD_SIZE - assert HEAD_SIZE == C // H - assert r.dtype == DTYPE - assert w.dtype == DTYPE - assert k.dtype == DTYPE - assert v.dtype == DTYPE - assert a.dtype == DTYPE - assert b.dtype == DTYPE - assert r.is_contiguous() - assert w.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - assert a.is_contiguous() - assert b.is_contiguous() - y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format) - torch.ops.wkv7.forward(B, T, C, H, r, w, k, v, a, b, y) - return y - - def RWKV7_OP(r, w, k, v, a, b): - return WKV_7.apply(r, w, k, v, a, b) - -else: - - def RWKV7_OP(r, w, k, v, a, b): - B, T, C = r.size() - H = C // HEAD_SIZE - N = HEAD_SIZE - r = r.view(B, T, H, N).float() - k = k.view(B, T, H, N).float() - v = v.view(B, T, H, N).float() - a = a.view(B, T, H, N).float() - b = b.view(B, T, H, N).float() - w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) - out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) - state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) - - for t in range(T): - kk = k[:, t, :].view(B, H, 1, N) - rr = r[:, t, :].view(B, H, N, 1) - vv = v[:, t, :].view(B, H, N, 1) - aa = a[:, t, :].view(B, H, N, 1) - bb = b[:, t, :].view(B, H, 1, N) - state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk - out[:, t, :] = (state @ rr).view(B, H, N) - - # another method using einsum - # - # kk = k[:, t, :] - # rr = r[:, t, :] - # vv = v[:, t, :] - # aa = a[:, t, :] - # bb = b[:, t, :] - # sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb) - # state = state * w[: , t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv) - # out[:, t, :] = torch.einsum('bhj,bhij->bhi', rr, state) - - return out.view(B, T, C).to(dtype=DTYPE) - ######################################################################################################## # RWKV TimeMix @@ -296,7 +207,32 @@ class RWKV_Tmix_x070(Module): kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) k = k * (1 + (a - 1) * self.k_a) + def RWKV7_OP(r, w, k, v, a, b): + B, T, C = r.size() + H = C // HEAD_SIZE + N = HEAD_SIZE + r = r.view(B, T, H, N).float() + k = k.view(B, T, H, N).float() + v = v.view(B, T, H, N).float() + a = a.view(B, T, H, N).float() + b = b.view(B, T, H, N).float() + w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) + out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) + state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) + + for t in range(T): + kk = k[:, t, :].view(B, H, 1, N) + rr = r[:, t, :].view(B, H, N, 1) + vv = v[:, t, :].view(B, H, N, 1) + aa = a[:, t, :].view(B, H, N, 1) + bb = b[:, t, :].view(B, H, 1, N) + state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk + out[:, t, :] = (state @ rr).view(B, H, N) + + return out.view(B, T, C).to(dtype=DTYPE) + x = RWKV7_OP(r, w, k, v, -kk, kk * a) + x = self.ln_x(x.view(B * T, C)).view(B, T, C) x = x + (