Spaces: Running on T4
sparkleman committed
Commit 3b5e872 · 1 Parent(s): aeaf225
UPDATE
- .gitignore (+3, -1)
- app.py (+10, -8)
- pyproject.toml (+1, -1)
- rwkv/__init__.py (+0, -0)
- {cuda → rwkv/cuda}/gemm_fp16_cublas.cpp (+75, -75)
- {cuda → rwkv/cuda}/operators.cu (+246, -246)
- {cuda → rwkv/cuda}/rwkv5.cu (+88, -88)
- {cuda → rwkv/cuda}/rwkv5_op.cpp (+34, -34)
- {cuda → rwkv/cuda}/rwkv6.cu (+87, -87)
- {cuda → rwkv/cuda}/rwkv6_op.cpp (+34, -34)
- rwkv/cuda/rwkv7.cu (+77, -0)
- rwkv/cuda/rwkv7_op.cpp (+26, -0)
- {cuda → rwkv/cuda}/wrapper.cpp (+141, -141)
- rwkv/model.py (+0, -0)
- rwkv/rwkv_tokenizer.py (+103, -0)
- rwkv/rwkv_vocab_v20230424.txt (+0, -0)
- rwkv/utils.py (+135, -0)
.gitignore
CHANGED
@@ -16,4 +16,6 @@ wheels/
 *.st
 *local*
 
-dist-frontend/
+dist-frontend/
+
+.vscode/
app.py
CHANGED
@@ -99,12 +99,6 @@ for model_config in CONFIG.MODELS:
     )
     logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
 
-    tmp_model = RWKV(
-        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
-        strategy=CONFIG.STRATEGY,
-    )
-    tmp_pipeline = PIPELINE(tmp_model, model_config.VOCAB)
-
     if model_config.DEFAULT_CHAT:
         if DEFALUT_MODEL_NAME != None:
             logger.info(
@@ -123,8 +117,16 @@ for model_config in CONFIG.MODELS:
 
     MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
     MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
-    MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
-    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = tmp_pipeline
+    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
+        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
+        strategy=CONFIG.STRATEGY,
+    )
+    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
+        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
+    )
+    if "cuda" in CONFIG.STRATEGY:
+        # torch.cuda.empty_cache()
+        gc.collect()
     logGPUState()
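For readers skimming the hunk above: the temporary tmp_model/tmp_pipeline pair is gone, and the RWKV model and PIPELINE are now built directly into the per-service MODEL_STORAGE entry, with a gc.collect() pass when a CUDA strategy is active. A minimal Python sketch of that flow follows; the load_one_model helper and the simplified ModelStorage stand-in are illustrative only, not code from app.py.

# Illustrative sketch of the new loading flow (assumes the vendored rwkv package;
# ModelStorage here is a simplified stand-in for the app's own class).
import gc
from dataclasses import dataclass
from typing import Any

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

@dataclass
class ModelStorage:
    MODEL_CONFIG: Any = None
    model: Any = None
    pipeline: Any = None

def load_one_model(storage: dict, model_config, strategy: str) -> None:
    entry = ModelStorage()
    entry.MODEL_CONFIG = model_config
    entry.model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=strategy,
    )
    entry.pipeline = PIPELINE(entry.model, model_config.VOCAB)
    storage[model_config.SERVICE_NAME] = entry
    if "cuda" in strategy:
        # the diff keeps torch.cuda.empty_cache() commented out; only gc runs here
        gc.collect()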
pyproject.toml
CHANGED
@@ -14,7 +14,7 @@ dependencies = [
     "pydantic-settings>=2.8.1",
     "pynvml>=12.0.0",
     "rich>=13.9.4",
-    "rwkv==0.8.28",
+    # "rwkv==0.8.28",
     "setuptools>=75.8.2",
     "snowflake-id>=1.0.2",
 ]
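Commenting out the rwkv pin fits the rest of the commit: the files added under rwkv/ vendor the library into the Space, so the published wheel is presumably no longer needed. A two-line sketch of the import side, assuming the standard module layout of the rwkv package:

# With rwkv/ vendored in the repo root, these imports resolve to the local copy
# added in this commit instead of the PyPI "rwkv" distribution.
from rwkv.model import RWKV
from rwkv.utils import PIPELINE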
rwkv/__init__.py
ADDED
File without changes
{cuda → rwkv/cuda}/gemm_fp16_cublas.cpp
RENAMED
@@ -1,75 +1,75 @@ (file moved; contents unchanged)
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

#define CUBLAS_CHECK(condition)                                  \
  for (cublasStatus_t _cublas_check_status = (condition);        \
       _cublas_check_status != CUBLAS_STATUS_SUCCESS;)           \
    throw std::runtime_error("cuBLAS error " +                   \
                             std::to_string(_cublas_check_status) + " at " + \
                             std::to_string(__LINE__));

#define CUDA_CHECK(condition)                                    \
  for (cudaError_t _cuda_check_status = (condition);             \
       _cuda_check_status != cudaSuccess;)                       \
    throw std::runtime_error(                                    \
        "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
        " at " + std::to_string(__LINE__));

/*
  NOTE: blas gemm is column-major by default, but we need row-major output.
  The data of row-major, transposed matrix is exactly the same as the
  column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  const auto cuda_data_type = CUDA_R_16F;
  const auto cuda_c_data_type =
      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
  const auto compute_type = CUDA_R_32F;
  const float sp_alpha = 1.f;
  // swap a and b, and use CUBLAS_OP_N. see the notes above
  std::swap(a, b);
  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
  // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
  // negative axis is used because of the existence of batch matmul.
  const int m = a.size(-1);
  const int k = a.size(-2);
  const int n = b.size(-2);
  const int cublas_lda = m;
  const int cublas_ldb = k;
  const int cublas_ldc = m;
  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();

#if CUDA_VERSION >= 11000
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
  const float sp_beta = 0.f;
  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
    CUBLAS_CHECK(cublasGemmEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
        compute_type, algo));
  } else {
    // batch matmul
    assert(a.sizes().size() == 3 && b.sizes().size() == 3);

    const long long int cublas_stride_a = m * k;
    const long long int cublas_stride_b = k * n;
    const long long int cublas_stride_c = m * n;
    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m,
        n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
        cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
        &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
        a.size(0), compute_type, algo));
  }
}
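The comment block in gemm_fp16_cublas.cpp is the whole trick: cuBLAS works in column-major order, so a row-major C = A·B is obtained by computing Cᵀ = Bᵀ·Aᵀ on the same buffers with swapped operands and leading dimensions. A small NumPy check of that identity (not part of the commit):

# Verifies the operand-swap trick used by gemm_fp16_cublas: a row-major buffer
# read as column-major is the transpose, so B^T @ A^T (column-major) equals
# A @ B (row-major) without any data movement.
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3)).astype(np.float32)   # row-major operands
B = rng.standard_normal((3, 5)).astype(np.float32)

A_cm = np.asfortranarray(A.T)   # zero-copy view: A's bytes seen column-major
B_cm = np.asfortranarray(B.T)
C_cm = B_cm @ A_cm              # what cublasGemmEx computes after std::swap(a, b)

assert np.allclose(C_cm.T, A @ B)   # reading the result row-major gives A @ B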
{cuda → rwkv/cuda}/operators.cu
RENAMED
@@ -1,246 +1,246 @@ (file moved; contents unchanged)
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#define MIN_VALUE (-1e38)
typedef at::Half fp16;
__half *cast(fp16 *ptr) {
    return reinterpret_cast<__half *>(ptr);
}

template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
                                   const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                                   F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;
    const int _state_offset = _b * C + _c;

    float u = _u[_c];
    float w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}

template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}

template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);

__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int k = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += x[i * x_stride + j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        y[i * y_stride + k] = y_local;
    }
}

template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}

__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int k = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += __half2float(x[i * x_stride + j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        y[i * y_stride + k] = __float2half(y_local);
    }
}

template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
        B, N, M, cast(x), x_stride, w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}

#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024

__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += x[j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        atomicAdd(&y[k], y_local);
    }
}

template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}

__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += __half2float(x[j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        atomicAdd(&y[k], y_local);
    }
}

template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
        N, M, cast(x), w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), y);
}
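operators.cu carries two kernel families: the numerically stabilised WKV recurrence (kernel_wkv_forward keeps a running maximum pp so both exponentials stay bounded) and the int8 matmul kernels, which dequantise each weight on the fly as (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j]. A NumPy reference of the per-channel WKV recurrence, handy for checking the CUDA kernel on small shapes (not part of the commit):

# Pure-NumPy reference of kernel_wkv_forward's per-channel recurrence.
# Shapes: w, u, aa, bb, pp -> (C,); k, v -> (T, C). Math follows the kernel.
import numpy as np

def wkv_forward_ref(w, u, k, v, aa, bb, pp):
    T = k.shape[0]
    y = np.empty_like(v, dtype=np.float32)
    for t in range(T):
        ww = u + k[t]
        p = np.maximum(pp, ww)                      # running max keeps exp() stable
        e1, e2 = np.exp(pp - p), np.exp(ww - p)
        y[t] = (e1 * aa + e2 * v[t]) / (e1 * bb + e2)

        ww = w + pp
        p = np.maximum(ww, k[t])
        e1, e2 = np.exp(ww - p), np.exp(k[t] - p)
        aa = e1 * aa + e2 * v[t]
        bb = e1 * bb + e2
        pp = p
    return y, (aa, bb, pp)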
{cuda → rwkv/cuda}/rwkv5.cu
RENAMED
@@ -1,88 +1,88 @@ (file moved; contents unchanged)
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _w += h*_N_;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);
    w[i] = _w[i];
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
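In matrix terms, each head in rwkv5.cu keeps an N×N state S (thread i owns row i) and, per token, computes y = v·sum(r·u·k) + S·r before decaying the state with the per-head w and adding outer(v, k). A NumPy reference of one head (not from the commit; the CUDA kernel additionally vectorises the inner loop with float4 loads):

# NumPy reference of one RWKV-5 head as implemented in rwkv5.cu.
# r, k, v: (T, N); w, u: (N,); S: (N, N) running state (modified in place).
import numpy as np

def rwkv5_head_ref(r, k, v, w, u, S):
    T, N = r.shape
    y = np.empty((T, N), dtype=np.float32)
    for t in range(T):
        y[t] = (r[t] * u * k[t]).sum() * v[t] + S @ r[t]   # output from old state
        S[:] = S * w[None, :] + np.outer(v[t], k[t])        # decay, then add k^T v
    return y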
{cuda → rwkv/cuda}/rwkv5_op.cpp
RENAMED
@@ -1,34 +1,34 @@ (file moved; contents unchanged)
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
}
TORCH_LIBRARY(rwkv5, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
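rwkv5_op.cpp only marshals tensors and exposes the kernels both through pybind11 and TORCH_LIBRARY(rwkv5, ...). A hedged sketch of how such a kernel/binding pair can be JIT-compiled and called from Python; the -D_N_=64 head size, the file paths and the tensor shapes are assumptions for illustration, not values taken from this commit:

# Hedged sketch (not from this commit): JIT-compile the rwkv5 kernel + binding
# and call it on dummy tensors. Head size, paths and shapes are assumptions.
import torch
from torch.utils.cpp_extension import load

rwkv5 = load(
    name="rwkv5",
    sources=["rwkv/cuda/rwkv5_op.cpp", "rwkv/cuda/rwkv5.cu"],
    extra_cuda_cflags=["-O3", "-D_N_=64"],   # _N_ is the per-head size the kernel expects
    verbose=True,
)

B, T, H, N = 1, 8, 2, 64          # the kernel's state handling assumes B == 1
C = H * N
dev = "cuda"
state = torch.zeros(H * N * N, device=dev)             # per-head N x N state
r, k, v = (torch.randn(B, T, C, device=dev) for _ in range(3))
w = torch.rand(H, N, device=dev)                        # decay factors in (0, 1)
u = torch.randn(H, N, device=dev)
y = torch.empty(B, T, C, device=dev)
rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)  # fills y in place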
{cuda → rwkv/cuda}/rwkv6.cu
RENAMED
@@ -1,87 +1,87 @@ (file moved; contents unchanged)
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = _w[t];
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
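The only functional change relative to rwkv5.cu is where the decay comes from: here w[i] = _w[t] is loaded inside the time loop, so the decay is per token rather than per head. The same NumPy reference with that one change (again not from the commit):

# NumPy reference of one RWKV-6 head: identical to the RWKV-5 sketch above
# except that the decay w is per token, shape (T, N), instead of per head.
import numpy as np

def rwkv6_head_ref(r, k, v, w, u, S):
    # r, k, v, w: (T, N); u: (N,); S: (N, N) running state (modified in place)
    T, N = r.shape
    y = np.empty((T, N), dtype=np.float32)
    for t in range(T):
        y[t] = (r[t] * u * k[t]).sum() * v[t] + S @ r[t]
        S[:] = S * w[t][None, :] + np.outer(v[t], k[t])
    return y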
{cuda → rwkv/cuda}/rwkv6_op.cpp
RENAMED
@@ -1,34 +1,34 @@ (file moved; contents unchanged)
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
rwkv/cuda/rwkv7.cu
ADDED
@@ -0,0 +1,77 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"

typedef at::Half fp16;
typedef at::BFloat16 bf16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
                               F *__restrict__ const _y)
{
    const int e = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

    for (int _t = 0; _t < T; _t++)
    {
        const int t = e*T*C + h*_N_ + i + _t * C;
        __syncthreads();
        r[i] = float(_r[t]);
        w[i] = __expf(-__expf(float(_w[t])));
        k[i] = float(_k[t]);
        a[i] = float(_a[t]);
        b[i] = float(_b[t]);
        __syncthreads();

        float sa = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            sa += a[j] * state[j];
        }

        float vv = float(_v[t]);
        float y = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            s = s * w[j] + k[j] * vv + sa * b[j];
            y += s * r[j];
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}
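rwkv7.cu is the genuinely new kernel in this commit. Per token it turns the raw decay into exp(-exp(w)), reads the state along a (sa = S·a), then updates S with the decay, outer(v, k) and the sa·b correction before emitting y = S·r from the updated state; the asserts restrict it to B == 1. A NumPy reference of one head (not from the commit):

# NumPy reference of one RWKV-7 head as implemented in rwkv7.cu.
# r, w_raw, k, v, a, b: (T, N); S: (N, N) running state (modified in place).
import numpy as np

def rwkv7_head_ref(r, w_raw, k, v, a, b, S):
    T, N = r.shape
    y = np.empty((T, N), dtype=np.float32)
    for t in range(T):
        w = np.exp(-np.exp(w_raw[t]))                 # decay squashed into (0, 1)
        sa = S @ a[t]                                 # state read-out along a
        S[:] = S * w[None, :] + np.outer(v[t], k[t]) + np.outer(sa, b[t])
        y[t] = S @ r[t]                               # output uses the updated state
    return y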
rwkv/cuda/rwkv7_op.cpp
ADDED
@@ -0,0 +1,26 @@
#include <torch/extension.h>
#include "ATen/ATen.h"

typedef at::Half fp16;
typedef at::BFloat16 bf16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
}

TORCH_LIBRARY(wkv7s, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
{cuda β rwkv/cuda}/wrapper.cpp
RENAMED
@@ -1,141 +1,141 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+#include <iostream>
+#include <c10/cuda/CUDAGuard.h>
+
+typedef at::Half fp16;
+
+template <typename F>
+void cuda_wkv_forward(int B, int T, int C,
+                      float *w, float *u, F *k, F *v, F *y,
+                      float *aa, float *bb, float *pp);
+template <typename F>
+void cuda_mm8_seq(int B, int N, int M,
+                  F *x, int x_stride,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  F *y, int y_stride);
+template <typename F>
+void cuda_mm8_one(int N, int M,
+                  F *x,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  float *y);
+
+void wkv_forward(int64_t B, int64_t T, int64_t C,
+                 torch::Tensor &w, torch::Tensor &u,
+                 torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
+                 torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+    switch (k.scalar_type()) {
+    case c10::ScalarType::Half:
+        cuda_wkv_forward(B, T, C,
+                         w.data_ptr<float>(), u.data_ptr<float>(),
+                         k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
+                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
+        break;
+    case c10::ScalarType::Float:
+        cuda_wkv_forward(B, T, C,
+                         w.data_ptr<float>(), u.data_ptr<float>(),
+                         k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
+                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
+        break;
+    default:
+        assert(false && "Only FP16 and FP32 are currently supported");
+    }
+}
+
+void mm8_seq(int64_t B, int64_t N, int64_t M,
+             torch::Tensor &x, torch::Tensor &w,
+             torch::Tensor &mx, torch::Tensor &rx,
+             torch::Tensor &my, torch::Tensor &ry,
+             torch::Tensor &y) {
+    assert(x.stride(1) == 1);
+    assert(w.stride(1) == 1);
+    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
+    assert(my.stride(0) == 1 && ry.stride(0) == 1);
+    assert(y.stride(1) == 1);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+    switch (x.scalar_type()) {
+    case c10::ScalarType::Half:
+        cuda_mm8_seq(
+            B, N, M,
+            x.data_ptr<fp16>(), x.stride(0),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
+            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
+            y.data_ptr<fp16>(), y.stride(0));
+        break;
+    case c10::ScalarType::Float:
+        cuda_mm8_seq(
+            B, N, M,
+            x.data_ptr<float>(), x.stride(0),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<float>(), rx.data_ptr<float>(),
+            my.data_ptr<float>(), ry.data_ptr<float>(),
+            y.data_ptr<float>(), y.stride(0));
+        break;
+    default:
+        assert(false && "Only FP16 and FP32 are currently supported");
+    }
+}
+void mm8_one(int64_t N, int64_t M,
+             torch::Tensor &x, torch::Tensor &w,
+             torch::Tensor &mx, torch::Tensor &rx,
+             torch::Tensor &my, torch::Tensor &ry,
+             torch::Tensor &y) {
+    assert(x.stride(0) == 1);
+    assert(w.stride(1) == 1);
+    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
+    assert(my.stride(0) == 1 && ry.stride(0) == 1);
+    assert(y.stride(0) == 1);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+    switch (x.scalar_type()) {
+    case c10::ScalarType::Half:
+        cuda_mm8_one(
+            N, M,
+            x.data_ptr<fp16>(),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
+            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
+            y.data_ptr<float>());
+        break;
+    case c10::ScalarType::Float:
+        cuda_mm8_one(
+            N, M,
+            x.data_ptr<float>(),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<float>(), rx.data_ptr<float>(),
+            my.data_ptr<float>(), ry.data_ptr<float>(),
+            y.data_ptr<float>());
+        break;
+    default:
+        assert(false && "Only FP16 and FP32 are currently supported");
+    }
+}
+
+using torch::Tensor;
+
+#ifndef DISABLE_CUBLAS_GEMM
+void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
+#endif
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("wkv_forward", &wkv_forward, "wkv forward");
+    m.def("mm8_seq", &mm8_seq, "mm8 seq");
+    m.def("mm8_one", &mm8_one, "mm8 one");
+#ifndef DISABLE_CUBLAS_GEMM
+    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
+#endif
+}
+
+TORCH_LIBRARY(rwkv, m) {
+    m.def("wkv_forward", wkv_forward);
+    m.def("mm8_seq", mm8_seq);
+    m.def("mm8_one", mm8_one);
+#ifndef DISABLE_CUBLAS_GEMM
+    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
+#endif
+}
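wrapper.cpp exposes the same kernels twice: as a pybind11 extension module and under `TORCH_LIBRARY(rwkv, ...)`, which makes them reachable as `torch.ops.rwkv.*` once the compiled extension is loaded. A minimal call sketch for the `torch.ops` path, assuming the extension has already been built and loaded; the tensor shapes below are illustrative assumptions, only the dtypes follow the dispatch in `wkv_forward` above.

# Hypothetical call sketch (not part of this commit); assumes the extension is already loaded.
import torch

B, T, C = 1, 16, 512
dev = "cuda"

w  = torch.randn(C, dtype=torch.float32, device=dev)      # per-channel decay (fp32)
u  = torch.randn(C, dtype=torch.float32, device=dev)      # per-channel bonus (fp32)
k  = torch.randn(B * T, C, dtype=torch.half, device=dev)  # Half or Float are dispatched
v  = torch.randn(B * T, C, dtype=torch.half, device=dev)
y  = torch.empty(B * T, C, dtype=torch.half, device=dev)
aa = torch.zeros(B, C, dtype=torch.float32, device=dev)   # running WKV state (fp32)
bb = torch.zeros(B, C, dtype=torch.float32, device=dev)
pp = torch.full((B, C), -1e30, dtype=torch.float32, device=dev)

# wkv_forward switches on k.scalar_type(); anything other than Half/Float asserts.
torch.ops.rwkv.wkv_forward(B, T, C, w, u, k, v, y, aa, bb, pp)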
rwkv/model.py
ADDED
The diff for this file is too large to render.
See raw diff
rwkv/rwkv_tokenizer.py
ADDED
@@ -0,0 +1,103 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+class TRIE:
+    __slots__ = tuple("ch,to,values,front".split(","))
+    to:list
+    values:set
+    def __init__(self, front=None, ch=None):
+        self.ch = ch
+        self.to = [None for ch in range(256)]
+        self.values = set()
+        self.front = front
+
+    def __repr__(self):
+        fr = self
+        ret = []
+        while(fr!=None):
+            if(fr.ch!=None):
+                ret.append(fr.ch)
+            fr = fr.front
+        return "<TRIE %s %s>"%(ret[::-1], self.values)
+
+    def add(self, key:bytes, idx:int=0, val=None):
+        if(idx == len(key)):
+            if(val is None):
+                val = key
+            self.values.add(val)
+            return self
+        ch = key[idx]
+        if(self.to[ch] is None):
+            self.to[ch] = TRIE(front=self, ch=ch)
+        return self.to[ch].add(key, idx=idx+1, val=val)
+
+    def find_longest(self, key:bytes, idx:int=0):
+        u:TRIE = self
+        ch:int = key[idx]
+
+        while(u.to[ch] is not None):
+            u = u.to[ch]
+            idx += 1
+            if(u.values):
+                ret = idx, u, u.values
+            if(idx==len(key)):
+                break
+            ch = key[idx]
+        return ret
+
+class TRIE_TOKENIZER():
+    def __init__(self, file_name):
+        self.idx2token = {}
+        sorted = [] # must be already sorted
+        with open(file_name, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        for l in lines:
+            idx = int(l[:l.index(' ')])
+            x = eval(l[l.index(' '):l.rindex(' ')])
+            x = x.encode("utf-8") if isinstance(x, str) else x
+            assert isinstance(x, bytes)
+            assert len(x) == int(l[l.rindex(' '):])
+            sorted += [x]
+            self.idx2token[idx] = x
+
+        self.token2idx = {}
+        for k,v in self.idx2token.items():
+            self.token2idx[v] = int(k)
+
+        self.root = TRIE()
+        for t, i in self.token2idx.items():
+            _ = self.root.add(t, val=(t, i))
+
+    def encodeBytes(self, src:bytes):
+        idx:int = 0
+        tokens = []
+        while (idx < len(src)):
+            _idx:int = idx
+            idx, _, values = self.root.find_longest(src, idx)
+            assert(idx != _idx)
+            _, token = next(iter(values))
+            tokens.append(token)
+        return tokens
+
+    def decodeBytes(self, tokens):
+        return b''.join(map(lambda i: self.idx2token[i], tokens))
+
+    def encode(self, src):
+        return self.encodeBytes(src.encode("utf-8"))
+
+    def decode(self, tokens):
+        try:
+            return self.decodeBytes(tokens).decode('utf-8')
+        except:
+            return '\ufffd' # bad utf-8
+
+    def printTokens(self, tokens):
+        for i in tokens:
+            s = self.idx2token[i]
+            try:
+                s = s.decode('utf-8')
+            except:
+                pass
+            print(f'{repr(s)}{i}', end=' ')
+        print()
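A short usage sketch for the trie tokenizer added above, using the vocabulary file added in this commit. The import path assumes the repository root is on sys.path; the sample string is illustrative.

# Hypothetical usage sketch (not part of this commit).
from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

tok = TRIE_TOKENIZER("rwkv/rwkv_vocab_v20230424.txt")
ids = tok.encode("Hello world")   # greedy longest-match over the UTF-8 bytes
print(ids)
print(tok.decode(ids))            # round-trips back to "Hello world"
tok.printTokens(ids)              # token-by-token debug print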
rwkv/rwkv_vocab_v20230424.txt
ADDED
The diff for this file is too large to render.
See raw diff
rwkv/utils.py
ADDED
@@ -0,0 +1,135 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, sys
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+class PIPELINE_ARGS():
+    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, alpha_decay=0.996, token_ban=[], token_stop=[], chunk_len=256):
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
+        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
+        self.alpha_decay = alpha_decay # gradually decay the penalty
+        self.token_ban = token_ban # ban the generation of some tokens
+        self.token_stop = token_stop # stop generation whenever you see any token here
+        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
+
+class PIPELINE():
+    def __init__(self, model, WORD_NAME):
+        self.model = model
+        if WORD_NAME == 'cl100k_base':
+            import tiktoken
+            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
+        elif WORD_NAME == 'rwkv_vocab_v20230424':
+            sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+            from rwkv_tokenizer import TRIE_TOKENIZER
+            self.tokenizer = TRIE_TOKENIZER(os.path.dirname(os.path.abspath(__file__)) + '/rwkv_vocab_v20230424.txt')
+        else:
+            from tokenizers import Tokenizer
+            self.tokenizer = Tokenizer.from_file(WORD_NAME)
+
+    def refine_context(self, context):
+        context = context.strip().split('\n')
+        for c in range(len(context)):
+            context[c] = context[c].strip().strip('\u3000').strip('\r')
+        context = list(filter(lambda c: c != '', context))
+        context = '\n' + ('\n'.join(context)).strip()
+        if context == '':
+            context = '\n'
+        return context
+
+    def encode(self, x):
+        if 'Tokenizer' in str(type(self.tokenizer)):
+            return self.tokenizer.encode(x).ids
+        else:
+            return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
+        if temperature == 0:
+            temperature = 1.0
+            top_p = 0
+        probs = F.softmax(logits.float(), dim=-1)
+        top_k = int(top_k)
+        # 'privateuseone' is the type of custom devices like `torch_directml.device()`
+        if probs.device.type in ['cpu', 'privateuseone']:
+            probs = probs.cpu().numpy()
+            sorted_ids = np.argsort(probs)
+            sorted_probs = probs[sorted_ids][::-1]
+            cumulative_probs = np.cumsum(sorted_probs)
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
+            probs[probs < cutoff] = 0
+            if top_k < len(probs) and top_k > 0:
+                probs[sorted_ids[:-top_k]] = 0
+            if temperature != 1.0:
+                probs = probs ** (1.0 / temperature)
+            probs = probs / np.sum(probs)
+            out = np.random.choice(a=len(probs), p=probs)
+            return int(out)
+        else:
+            sorted_ids = torch.argsort(probs)
+            sorted_probs = probs[sorted_ids]
+            sorted_probs = torch.flip(sorted_probs, dims=(0,))
+            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
+            probs[probs < cutoff] = 0
+            if top_k < len(probs) and top_k > 0:
+                probs[sorted_ids[:-top_k]] = 0
+            if temperature != 1.0:
+                probs = probs ** (1.0 / temperature)
+            out = torch.multinomial(probs, num_samples=1)[0]
+            return int(out)
+
+    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
+        all_tokens = []
+        out_last = 0
+        out_str = ''
+        occurrence = {}
+        for i in range(token_count):
+
+            # forward & adjust prob.
+            tokens = self.encode(ctx) if i == 0 else [token]
+            while len(tokens) > 0:
+                out, state = self.model.forward(tokens[:args.chunk_len], state)
+                tokens = tokens[args.chunk_len:]
+
+            for n in args.token_ban:
+                out[n] = -float('inf')
+            for n in occurrence:
+                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+            # sampler
+            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
+            if token in args.token_stop:
+                break
+            all_tokens += [token]
+            for xxx in occurrence:
+                occurrence[xxx] *= args.alpha_decay
+
+            ttt = self.decode([token])
+            www = 1
+            if ttt in ' \t0123456789':
+                www = 0
+            # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+            #     www = 0.5
+            if token not in occurrence:
+                occurrence[token] = www
+            else:
+                occurrence[token] += www
+            # print(occurrence) # debug
+
+            # output
+            tmp = self.decode(all_tokens[out_last:])
+            if '\ufffd' not in tmp: # is valid utf-8 string?
+                if callback:
+                    callback(tmp)
+                out_str += tmp
+                out_last = i + 1
+        return out_str
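A short usage sketch for the pipeline added above. The model path, strategy string, sampling settings, and stop-token id are illustrative assumptions; only the class and method names come from the code in this commit.

# Hypothetical usage sketch (not part of this commit).
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

model = RWKV(model="path/to/model", strategy="cuda fp16")   # placeholder path and strategy
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

args = PIPELINE_ARGS(
    temperature=1.0,
    top_p=0.85,
    alpha_frequency=0.2,   # frequency penalty, decayed each step by alpha_decay
    alpha_presence=0.2,
    token_stop=[0],        # assumed end-of-text token id
    chunk_len=256,         # prompt is fed in chunks of this size to save VRAM
)

# generate() streams valid UTF-8 chunks through the callback and returns the full string.
pipeline.generate(
    "User: Why is the sky blue?\n\nAssistant:",
    token_count=200,
    args=args,
    callback=lambda chunk: print(chunk, end="", flush=True),
)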