#include #include "ATen/ATen.h" typedef at::Half bf16; // typedef at::BFloat16 bf16; void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y); void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { cuda_forward(B, T, C, H, state.data_ptr(), r.data_ptr(), w.data_ptr(), k.data_ptr(), v.data_ptr(), a.data_ptr(), b.data_ptr(), y.data_ptr()); } TORCH_LIBRARY(wkv7s, m) { m.def("forward", forward); }