sparkleman committed on
Commit 3b5e872 · 1 Parent(s): aeaf225
.gitignore CHANGED
@@ -16,4 +16,6 @@ wheels/
 *.st
 *local*
 
-dist-frontend/
+dist-frontend/
+
+.vscode/
app.py CHANGED
@@ -99,12 +99,6 @@ for model_config in CONFIG.MODELS:
     )
     logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
 
-    tmp_model = RWKV(
-        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
-        strategy=CONFIG.STRATEGY,
-    )
-    tmp_pipeline = PIPELINE(tmp_model, model_config.VOCAB)
-
     if model_config.DEFAULT_CHAT:
         if DEFALUT_MODEL_NAME != None:
             logger.info(
@@ -123,8 +117,16 @@ for model_config in CONFIG.MODELS:
 
     MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
     MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
-    MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
-    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = tmp_pipeline
+    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
+        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
+        strategy=CONFIG.STRATEGY,
+    )
+    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
+        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
+    )
+    if "cuda" in CONFIG.STRATEGY:
+        # torch.cuda.empty_cache()
+        gc.collect()
 
     logGPUState()
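For readers skimming the diff, the per-model load path after this commit boils down to the short sketch below. It is a condensed sketch, not the full app.py: CONFIG, ModelStorage, MODEL_STORAGE and logGPUState are the objects defined in app.py, and the import paths assume the vendored rwkv package exposes RWKV and PIPELINE the way the upstream rwkv pip package does (rwkv.model.RWKV, rwkv.utils.PIPELINE).

# Condensed sketch of the new load path (assumptions noted above).
import gc

from rwkv.model import RWKV      # vendored package added in this commit
from rwkv.utils import PIPELINE  # assumed location, as in the upstream rwkv package


def load_into_storage(model_config) -> None:
    storage = ModelStorage()
    storage.MODEL_CONFIG = model_config
    # RWKV() expects the checkpoint path without its ".pth" suffix.
    storage.model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=CONFIG.STRATEGY,
    )
    storage.pipeline = PIPELINE(storage.model, model_config.VOCAB)
    if "cuda" in CONFIG.STRATEGY:
        gc.collect()  # drop loader temporaries before the next model is loaded
    MODEL_STORAGE[model_config.SERVICE_NAME] = storage
    logGPUState()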
pyproject.toml CHANGED
@@ -14,7 +14,7 @@ dependencies = [
     "pydantic-settings>=2.8.1",
     "pynvml>=12.0.0",
     "rich>=13.9.4",
-    "rwkv==0.8.28",
+    # "rwkv==0.8.28",
     "setuptools>=75.8.2",
     "snowflake-id>=1.0.2",
 ]
rwkv/__init__.py ADDED
File without changes
{cuda → rwkv/cuda}/gemm_fp16_cublas.cpp RENAMED
@@ -1,75 +1,75 @@
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

#define CUBLAS_CHECK(condition) \
    for (cublasStatus_t _cublas_check_status = (condition); \
         _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
        throw std::runtime_error("cuBLAS error " + \
                                 std::to_string(_cublas_check_status) + " at " + \
                                 std::to_string(__LINE__));

#define CUDA_CHECK(condition) \
    for (cudaError_t _cuda_check_status = (condition); \
         _cuda_check_status != cudaSuccess;) \
        throw std::runtime_error( \
            "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
            " at " + std::to_string(__LINE__));

/*
  NOTE: blas gemm is column-major by default, but we need row-major output.
  The data of row-major, transposed matrix is exactly the same as the
  column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
    const auto cuda_data_type = CUDA_R_16F;
    const auto cuda_c_data_type =
        c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
    const auto compute_type = CUDA_R_32F;
    const float sp_alpha = 1.f;
    // swap a and b, and use CUBLAS_OP_N. see the notes above
    std::swap(a, b);
    const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
    const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
    // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
    // negative axis is used because of the existence of batch matmul.
    const int m = a.size(-1);
    const int k = a.size(-2);
    const int n = b.size(-2);
    const int cublas_lda = m;
    const int cublas_ldb = k;
    const int cublas_ldc = m;
    cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();

#if CUDA_VERSION >= 11000
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
    cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
    const float sp_beta = 0.f;
    if (a.sizes().size() == 2 && b.sizes().size() == 2) {
        CUBLAS_CHECK(cublasGemmEx(
            cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
            a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
            cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
            compute_type, algo));
    } else {
        // batch matmul
        assert(a.sizes().size() == 3 && b.sizes().size() == 3);

        const long long int cublas_stride_a = m * k;
        const long long int cublas_stride_b = k * n;
        const long long int cublas_stride_c = m * n;
        CUBLAS_CHECK(cublasGemmStridedBatchedEx(
            cublas_handle, cublas_trans_a, cublas_trans_b, m,
            n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
            cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
            &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
            a.size(0), compute_type, algo));
    }
}
{cuda → rwkv/cuda}/operators.cu RENAMED
@@ -1,246 +1,246 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#define MIN_VALUE (-1e38)
typedef at::Half fp16;
__half *cast(fp16 *ptr) {
    return reinterpret_cast<__half *>(ptr);
}

template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
    const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
    F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;
    const int _state_offset = _b * C + _c;

    float u = _u[_c];
    float w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}

template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}

template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);

__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int k = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += x[i * x_stride + j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        y[i * y_stride + k] = y_local;
    }
}

template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}

__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int k = blockIdx.y * blockDim.y + threadIdx.y;

    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += __half2float(x[i * x_stride + j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        y[i * y_stride + k] = __float2half(y_local);
    }
}

template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
        B, N, M, cast(x), x_stride, w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}

#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024

__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += x[j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        atomicAdd(&y[k], y_local);
    }
}

template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}

__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {

    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));

    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += __half2float(x[j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        atomicAdd(&y[k], y_local);
    }
}

template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
        N, M, cast(x), w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), y);
}
{cuda → rwkv/cuda}/rwkv5.cu RENAMED
@@ -1,88 +1,88 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
    F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _w += h*_N_;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);
    w[i] = _w[i];
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
{cuda → rwkv/cuda}/rwkv5_op.cpp RENAMED
@@ -1,34 +1,34 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
}
TORCH_LIBRARY(rwkv5, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
{cuda → rwkv/cuda}/rwkv6.cu RENAMED
@@ -1,87 +1,87 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
    F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = _w[t];
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
{cuda → rwkv/cuda}/rwkv6_op.cpp RENAMED
@@ -1,34 +1,34 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
rwkv/cuda/rwkv7.cu ADDED
@@ -0,0 +1,77 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+
+typedef at::Half fp16;
+typedef at::BFloat16 bf16;
+typedef float fp32;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+    float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
+    F *__restrict__ const _y)
+{
+    const int e = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
+
+    float state[_N_];
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        state[j] = _state[j];
+
+    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
+
+    for (int _t = 0; _t < T; _t++)
+    {
+        const int t = e*T*C + h*_N_ + i + _t * C;
+        __syncthreads();
+        r[i] = float(_r[t]);
+        w[i] = __expf(-__expf(float(_w[t])));
+        k[i] = float(_k[t]);
+        a[i] = float(_a[t]);
+        b[i] = float(_b[t]);
+        __syncthreads();
+
+        float sa = 0;
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            sa += a[j] * state[j];
+        }
+
+        float vv = float(_v[t]);
+        float y = 0;
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            s = s * w[j] + k[j] * vv + sa * b[j];
+            y += s * r[j];
+        }
+        _y[t] = F(y);
+    }
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        _state[j] = state[j];
+}
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
+{
+    assert(H*_N_ == C);
+    assert(B == 1); // only for B=1
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
+}
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
+{
+    assert(H*_N_ == C);
+    assert(B == 1); // only for B=1
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
+}
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
+{
+    assert(H*_N_ == C);
+    assert(B == 1); // only for B=1
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
+}
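Read directly from the kernel above: each thread owns one row i of an N-by-N per-head state S (N = _N_), the decay enters through a double exponential, and one time step performs

w_{t,j} = \exp(-\exp(\hat{w}_{t,j})), \qquad
\mathrm{sa}_i = \sum_j a_{t,j}\, S_{i,j}, \qquad
S_{i,j} \leftarrow S_{i,j}\, w_{t,j} + k_{t,j}\, v_{t,i} + \mathrm{sa}_i\, b_{t,j}, \qquad
y_{t,i} = \sum_j S_{i,j}\, r_{t,j}.

This is only a restatement of the loop body for reference; the comment on the state pointer and the assert(B == 1) in the launchers make explicit that the state indexing is valid for batch size 1 only.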
rwkv/cuda/rwkv7_op.cpp ADDED
@@ -0,0 +1,26 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+
+typedef at::Half fp16;
+typedef at::BFloat16 bf16;
+typedef float fp32;
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
+
+void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
+    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
+    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
+}
+void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
+    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
+}
+
+TORCH_LIBRARY(wkv7s, m) {
+    m.def("forward_bf16", forward_bf16);
+    m.def("forward_fp16", forward_fp16);
+    m.def("forward_fp32", forward_fp32);
+}
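The binding above registers its operators under TORCH_LIBRARY(wkv7s, ...), so once the extension is compiled they appear in the torch.ops namespace. Below is a hedged sketch of how such a kernel is typically built and invoked with torch.utils.cpp_extension.load; the HEAD_SIZE value, tensor shapes, and the exact way rwkv/model.py wires this up are assumptions for illustration, not taken from this commit.

# Hedged sketch: JIT-compile the rwkv7 kernel and call the registered op.
import torch
from torch.utils.cpp_extension import load

HEAD_SIZE = 64  # assumed value of _N_; must match the model's head size

# Build the op library; the ops become visible under torch.ops.wkv7s.
load(
    name="wkv7s",
    sources=["rwkv/cuda/rwkv7_op.cpp", "rwkv/cuda/rwkv7.cu"],
    is_python_module=False,
    verbose=True,
    extra_cuda_cflags=["-O3", f"-D_N_={HEAD_SIZE}"],
)

B, T, C = 1, 8, 512  # the kernel launchers assert B == 1
H = C // HEAD_SIZE
state = torch.zeros(H, HEAD_SIZE, HEAD_SIZE, dtype=torch.float32, device="cuda")
r, w, k, v, a, b = (torch.zeros(T, C, dtype=torch.half, device="cuda") for _ in range(6))
y = torch.empty(T, C, dtype=torch.half, device="cuda")

# Registered by TORCH_LIBRARY(wkv7s, ...) in rwkv7_op.cpp; writes y and updates state in place.
torch.ops.wkv7s.forward_fp16(B, T, C, H, state, r, w, k, v, a, b, y)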
{cuda β†’ rwkv/cuda}/wrapper.cpp RENAMED
@@ -1,141 +1,141 @@
1
- #include <torch/extension.h>
2
- #include "ATen/ATen.h"
3
- #include <iostream>
4
- #include <c10/cuda/CUDAGuard.h>
5
-
6
- typedef at::Half fp16;
7
-
8
- template <typename F>
9
- void cuda_wkv_forward(int B, int T, int C,
10
- float *w, float *u, F *k, F *v, F *y,
11
- float *aa, float *bb, float *pp);
12
- template <typename F>
13
- void cuda_mm8_seq(int B, int N, int M,
14
- F *x, int x_stride,
15
- uint8_t *w, int w_stride,
16
- F *mx, F *rx,
17
- F *my, F *ry,
18
- F *y, int y_stride);
19
- template <typename F>
20
- void cuda_mm8_one(int N, int M,
21
- F *x,
22
- uint8_t *w, int w_stride,
23
- F *mx, F *rx,
24
- F *my, F *ry,
25
- float *y);
26
-
27
- void wkv_forward(int64_t B, int64_t T, int64_t C,
28
- torch::Tensor &w, torch::Tensor &u,
29
- torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
30
- torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
31
- const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
32
- switch (k.scalar_type()) {
33
- case c10::ScalarType::Half:
34
- cuda_wkv_forward(B, T, C,
35
- w.data_ptr<float>(), u.data_ptr<float>(),
36
- k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
37
- aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
38
- break;
39
- case c10::ScalarType::Float:
40
- cuda_wkv_forward(B, T, C,
41
- w.data_ptr<float>(), u.data_ptr<float>(),
42
- k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
43
- aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
44
- break;
45
- default:
46
- assert(false && "Only FP16 and FP32 are currently supported");
47
- }
48
- }
49
-
50
- void mm8_seq(int64_t B, int64_t N, int64_t M,
51
- torch::Tensor &x, torch::Tensor &w,
52
- torch::Tensor &mx, torch::Tensor &rx,
53
- torch::Tensor &my, torch::Tensor &ry,
54
- torch::Tensor &y) {
55
- assert(x.stride(1) == 1);
56
- assert(w.stride(1) == 1);
57
- assert(mx.stride(0) == 1 && rx.stride(0) == 1);
58
- assert(my.stride(0) == 1 && ry.stride(0) == 1);
59
- assert(y.stride(1) == 1);
60
- const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
61
- switch (x.scalar_type()) {
62
- case c10::ScalarType::Half:
63
- cuda_mm8_seq(
64
- B, N, M,
65
- x.data_ptr<fp16>(), x.stride(0),
66
- w.data_ptr<uint8_t>(), w.stride(0),
67
- mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
68
- my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
69
- y.data_ptr<fp16>(), y.stride(0));
70
- break;
71
- case c10::ScalarType::Float:
72
- cuda_mm8_seq(
73
- B, N, M,
74
- x.data_ptr<float>(), x.stride(0),
75
-             w.data_ptr<uint8_t>(), w.stride(0),
-             mx.data_ptr<float>(), rx.data_ptr<float>(),
-             my.data_ptr<float>(), ry.data_ptr<float>(),
-             y.data_ptr<float>(), y.stride(0));
-         break;
-     default:
-         assert(false && "Only FP16 and FP32 are currently supported");
-     }
- }
- void mm8_one(int64_t N, int64_t M,
-              torch::Tensor &x, torch::Tensor &w,
-              torch::Tensor &mx, torch::Tensor &rx,
-              torch::Tensor &my, torch::Tensor &ry,
-              torch::Tensor &y) {
-     assert(x.stride(0) == 1);
-     assert(w.stride(1) == 1);
-     assert(mx.stride(0) == 1 && rx.stride(0) == 1);
-     assert(my.stride(0) == 1 && ry.stride(0) == 1);
-     assert(y.stride(0) == 1);
-     const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
-     switch (x.scalar_type()) {
-     case c10::ScalarType::Half:
-         cuda_mm8_one(
-             N, M,
-             x.data_ptr<fp16>(),
-             w.data_ptr<uint8_t>(), w.stride(0),
-             mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
-             my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
-             y.data_ptr<float>());
-         break;
-     case c10::ScalarType::Float:
-         cuda_mm8_one(
-             N, M,
-             x.data_ptr<float>(),
-             w.data_ptr<uint8_t>(), w.stride(0),
-             mx.data_ptr<float>(), rx.data_ptr<float>(),
-             my.data_ptr<float>(), ry.data_ptr<float>(),
-             y.data_ptr<float>());
-         break;
-     default:
-         assert(false && "Only FP16 and FP32 are currently supported");
-     }
- }
-
- using torch::Tensor;
-
- #ifndef DISABLE_CUBLAS_GEMM
- void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
- #endif
-
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-     m.def("wkv_forward", &wkv_forward, "wkv forward");
-     m.def("mm8_seq", &mm8_seq, "mm8 seq");
-     m.def("mm8_one", &mm8_one, "mm8 one");
- #ifndef DISABLE_CUBLAS_GEMM
-     m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
- #endif
- }
-
- TORCH_LIBRARY(rwkv, m) {
-     m.def("wkv_forward", wkv_forward);
-     m.def("mm8_seq", mm8_seq);
-     m.def("mm8_one", mm8_one);
- #ifndef DISABLE_CUBLAS_GEMM
-     m.def("gemm_fp16_cublas", gemm_fp16_cublas);
- #endif
- }

+ #include <torch/extension.h>
+ #include "ATen/ATen.h"
+ #include <iostream>
+ #include <c10/cuda/CUDAGuard.h>
+
+ typedef at::Half fp16;
+
+ template <typename F>
+ void cuda_wkv_forward(int B, int T, int C,
+                       float *w, float *u, F *k, F *v, F *y,
+                       float *aa, float *bb, float *pp);
+ template <typename F>
+ void cuda_mm8_seq(int B, int N, int M,
+                   F *x, int x_stride,
+                   uint8_t *w, int w_stride,
+                   F *mx, F *rx,
+                   F *my, F *ry,
+                   F *y, int y_stride);
+ template <typename F>
+ void cuda_mm8_one(int N, int M,
+                   F *x,
+                   uint8_t *w, int w_stride,
+                   F *mx, F *rx,
+                   F *my, F *ry,
+                   float *y);
+
+ void wkv_forward(int64_t B, int64_t T, int64_t C,
+                  torch::Tensor &w, torch::Tensor &u,
+                  torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
+                  torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+     switch (k.scalar_type()) {
+     case c10::ScalarType::Half:
+         cuda_wkv_forward(B, T, C,
+                          w.data_ptr<float>(), u.data_ptr<float>(),
+                          k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
+                          aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
+         break;
+     case c10::ScalarType::Float:
+         cuda_wkv_forward(B, T, C,
+                          w.data_ptr<float>(), u.data_ptr<float>(),
+                          k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
+                          aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
+         break;
+     default:
+         assert(false && "Only FP16 and FP32 are currently supported");
+     }
+ }
+
+ void mm8_seq(int64_t B, int64_t N, int64_t M,
+              torch::Tensor &x, torch::Tensor &w,
+              torch::Tensor &mx, torch::Tensor &rx,
+              torch::Tensor &my, torch::Tensor &ry,
+              torch::Tensor &y) {
+     assert(x.stride(1) == 1);
+     assert(w.stride(1) == 1);
+     assert(mx.stride(0) == 1 && rx.stride(0) == 1);
+     assert(my.stride(0) == 1 && ry.stride(0) == 1);
+     assert(y.stride(1) == 1);
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+     switch (x.scalar_type()) {
+     case c10::ScalarType::Half:
+         cuda_mm8_seq(
+             B, N, M,
+             x.data_ptr<fp16>(), x.stride(0),
+             w.data_ptr<uint8_t>(), w.stride(0),
+             mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
+             my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
+             y.data_ptr<fp16>(), y.stride(0));
+         break;
+     case c10::ScalarType::Float:
+         cuda_mm8_seq(
+             B, N, M,
+             x.data_ptr<float>(), x.stride(0),
+             w.data_ptr<uint8_t>(), w.stride(0),
+             mx.data_ptr<float>(), rx.data_ptr<float>(),
+             my.data_ptr<float>(), ry.data_ptr<float>(),
+             y.data_ptr<float>(), y.stride(0));
+         break;
+     default:
+         assert(false && "Only FP16 and FP32 are currently supported");
+     }
+ }
+ void mm8_one(int64_t N, int64_t M,
+              torch::Tensor &x, torch::Tensor &w,
+              torch::Tensor &mx, torch::Tensor &rx,
+              torch::Tensor &my, torch::Tensor &ry,
+              torch::Tensor &y) {
+     assert(x.stride(0) == 1);
+     assert(w.stride(1) == 1);
+     assert(mx.stride(0) == 1 && rx.stride(0) == 1);
+     assert(my.stride(0) == 1 && ry.stride(0) == 1);
+     assert(y.stride(0) == 1);
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+     switch (x.scalar_type()) {
+     case c10::ScalarType::Half:
+         cuda_mm8_one(
+             N, M,
+             x.data_ptr<fp16>(),
+             w.data_ptr<uint8_t>(), w.stride(0),
+             mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
+             my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
+             y.data_ptr<float>());
+         break;
+     case c10::ScalarType::Float:
+         cuda_mm8_one(
+             N, M,
+             x.data_ptr<float>(),
+             w.data_ptr<uint8_t>(), w.stride(0),
+             mx.data_ptr<float>(), rx.data_ptr<float>(),
+             my.data_ptr<float>(), ry.data_ptr<float>(),
+             y.data_ptr<float>());
+         break;
+     default:
+         assert(false && "Only FP16 and FP32 are currently supported");
+     }
+ }
+
+ using torch::Tensor;
+
+ #ifndef DISABLE_CUBLAS_GEMM
+ void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
+ #endif
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("wkv_forward", &wkv_forward, "wkv forward");
+     m.def("mm8_seq", &mm8_seq, "mm8 seq");
+     m.def("mm8_one", &mm8_one, "mm8 one");
+ #ifndef DISABLE_CUBLAS_GEMM
+     m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
+ #endif
+ }
+
+ TORCH_LIBRARY(rwkv, m) {
+     m.def("wkv_forward", wkv_forward);
+     m.def("mm8_seq", mm8_seq);
+     m.def("mm8_one", mm8_one);
+ #ifndef DISABLE_CUBLAS_GEMM
+     m.def("gemm_fp16_cublas", gemm_fp16_cublas);
+ #endif
+ }
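
Note: the wrapper above only declares and registers the extension entry points. Below is a minimal sketch of how such a wrapper is typically JIT-compiled and called from Python; the source file names and build flags are assumptions for illustration, not the exact values used by rwkv/model.py.

# Hypothetical JIT build of the C++/CUDA wrapper above (paths and flags are assumptions).
from torch.utils.cpp_extension import load

rwkv_cuda = load(
    name="rwkv_cuda_demo",
    sources=["rwkv/cuda/wrapper.cpp", "rwkv/cuda/operators.cu"],  # assumed source layout
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=True,
)
# The functions registered via PYBIND11_MODULE become plain Python callables, e.g.
# rwkv_cuda.mm8_one(N, M, x, w, mx, rx, my, ry, y) with fp16/fp32 activations and uint8 weights,
# matching the dtype switch statements above.
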
rwkv/model.py ADDED
The diff for this file is too large to render. See raw diff
 
rwkv/rwkv_tokenizer.py ADDED
@@ -0,0 +1,103 @@
+ ########################################################################################################
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+ ########################################################################################################
+
+ class TRIE:
+     __slots__ = tuple("ch,to,values,front".split(","))
+     to:list
+     values:set
+     def __init__(self, front=None, ch=None):
+         self.ch = ch
+         self.to = [None for ch in range(256)]
+         self.values = set()
+         self.front = front
+
+     def __repr__(self):
+         fr = self
+         ret = []
+         while(fr!=None):
+             if(fr.ch!=None):
+                 ret.append(fr.ch)
+             fr = fr.front
+         return "<TRIE %s %s>"%(ret[::-1], self.values)
+
+     def add(self, key:bytes, idx:int=0, val=None):
+         if(idx == len(key)):
+             if(val is None):
+                 val = key
+             self.values.add(val)
+             return self
+         ch = key[idx]
+         if(self.to[ch] is None):
+             self.to[ch] = TRIE(front=self, ch=ch)
+         return self.to[ch].add(key, idx=idx+1, val=val)
+
+     def find_longest(self, key:bytes, idx:int=0):
+         u:TRIE = self
+         ch:int = key[idx]
+
+         while(u.to[ch] is not None):
+             u = u.to[ch]
+             idx += 1
+             if(u.values):
+                 ret = idx, u, u.values
+             if(idx==len(key)):
+                 break
+             ch = key[idx]
+         return ret
+
+ class TRIE_TOKENIZER():
+     def __init__(self, file_name):
+         self.idx2token = {}
+         sorted = [] # must be already sorted
+         with open(file_name, "r", encoding="utf-8") as f:
+             lines = f.readlines()
+         for l in lines:
+             idx = int(l[:l.index(' ')])
+             x = eval(l[l.index(' '):l.rindex(' ')])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+             assert len(x) == int(l[l.rindex(' '):])
+             sorted += [x]
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k,v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         self.root = TRIE()
+         for t, i in self.token2idx.items():
+             _ = self.root.add(t, val=(t, i))
+
+     def encodeBytes(self, src:bytes):
+         idx:int = 0
+         tokens = []
+         while (idx < len(src)):
+             _idx:int = idx
+             idx, _, values = self.root.find_longest(src, idx)
+             assert(idx != _idx)
+             _, token = next(iter(values))
+             tokens.append(token)
+         return tokens
+
+     def decodeBytes(self, tokens):
+         return b''.join(map(lambda i: self.idx2token[i], tokens))
+
+     def encode(self, src):
+         return self.encodeBytes(src.encode("utf-8"))
+
+     def decode(self, tokens):
+         try:
+             return self.decodeBytes(tokens).decode('utf-8')
+         except:
+             return '\ufffd' # bad utf-8
+
+     def printTokens(self, tokens):
+         for i in tokens:
+             s = self.idx2token[i]
+             try:
+                 s = s.decode('utf-8')
+             except:
+                 pass
+             print(f'{repr(s)}{i}', end=' ')
+         print()
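
For reference, a quick usage sketch of the greedy longest-match tokenizer added above; the vocabulary path simply points at the file added in this commit and assumes the script is run from the repository root.

# Minimal usage sketch (path assumed relative to the repository root).
from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

tokenizer = TRIE_TOKENIZER("rwkv/rwkv_vocab_v20230424.txt")
ids = tokenizer.encode("Hello, world!")
assert tokenizer.decode(ids) == "Hello, world!"  # encode/decode round-trips losslessly
tokenizer.printTokens(ids)
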
rwkv/rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
rwkv/utils.py ADDED
@@ -0,0 +1,135 @@
+ ########################################################################################################
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+ ########################################################################################################
+
+ import os, sys
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+
+ class PIPELINE_ARGS():
+     def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, alpha_decay=0.996, token_ban=[], token_stop=[], chunk_len=256):
+         self.temperature = temperature
+         self.top_p = top_p
+         self.top_k = top_k
+         self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
+         self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
+         self.alpha_decay = alpha_decay # gradually decay the penalty
+         self.token_ban = token_ban # ban the generation of some tokens
+         self.token_stop = token_stop # stop generation whenever you see any token here
+         self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
+
+ class PIPELINE():
+     def __init__(self, model, WORD_NAME):
+         self.model = model
+         if WORD_NAME == 'cl100k_base':
+             import tiktoken
+             self.tokenizer = tiktoken.get_encoding(WORD_NAME)
+         elif WORD_NAME == 'rwkv_vocab_v20230424':
+             sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+             from rwkv_tokenizer import TRIE_TOKENIZER
+             self.tokenizer = TRIE_TOKENIZER(os.path.dirname(os.path.abspath(__file__)) + '/rwkv_vocab_v20230424.txt')
+         else:
+             from tokenizers import Tokenizer
+             self.tokenizer = Tokenizer.from_file(WORD_NAME)
+
+     def refine_context(self, context):
+         context = context.strip().split('\n')
+         for c in range(len(context)):
+             context[c] = context[c].strip().strip('\u3000').strip('\r')
+         context = list(filter(lambda c: c != '', context))
+         context = '\n' + ('\n'.join(context)).strip()
+         if context == '':
+             context = '\n'
+         return context
+
+     def encode(self, x):
+         if 'Tokenizer' in str(type(self.tokenizer)):
+             return self.tokenizer.encode(x).ids
+         else:
+             return self.tokenizer.encode(x)
+
+     def decode(self, x):
+         return self.tokenizer.decode(x)
+
+     def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
+         if temperature == 0:
+             temperature = 1.0
+             top_p = 0
+         probs = F.softmax(logits.float(), dim=-1)
+         top_k = int(top_k)
+         # 'privateuseone' is the type of custom devices like `torch_directml.device()`
+         if probs.device.type in ['cpu', 'privateuseone']:
+             probs = probs.cpu().numpy()
+             sorted_ids = np.argsort(probs)
+             sorted_probs = probs[sorted_ids][::-1]
+             cumulative_probs = np.cumsum(sorted_probs)
+             cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
+             probs[probs < cutoff] = 0
+             if top_k < len(probs) and top_k > 0:
+                 probs[sorted_ids[:-top_k]] = 0
+             if temperature != 1.0:
+                 probs = probs ** (1.0 / temperature)
+             probs = probs / np.sum(probs)
+             out = np.random.choice(a=len(probs), p=probs)
+             return int(out)
+         else:
+             sorted_ids = torch.argsort(probs)
+             sorted_probs = probs[sorted_ids]
+             sorted_probs = torch.flip(sorted_probs, dims=(0,))
+             cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
+             cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
+             probs[probs < cutoff] = 0
+             if top_k < len(probs) and top_k > 0:
+                 probs[sorted_ids[:-top_k]] = 0
+             if temperature != 1.0:
+                 probs = probs ** (1.0 / temperature)
+             out = torch.multinomial(probs, num_samples=1)[0]
+             return int(out)
+
+     def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
+         all_tokens = []
+         out_last = 0
+         out_str = ''
+         occurrence = {}
+         for i in range(token_count):
+
+             # forward & adjust prob.
+             tokens = self.encode(ctx) if i == 0 else [token]
+             while len(tokens) > 0:
+                 out, state = self.model.forward(tokens[:args.chunk_len], state)
+                 tokens = tokens[args.chunk_len:]
+
+             for n in args.token_ban:
+                 out[n] = -float('inf')
+             for n in occurrence:
+                 out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+             # sampler
+             token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
+             if token in args.token_stop:
+                 break
+             all_tokens += [token]
+             for xxx in occurrence:
+                 occurrence[xxx] *= args.alpha_decay
+
+             ttt = self.decode([token])
+             www = 1
+             if ttt in ' \t0123456789':
+                 www = 0
+             # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{}，。；""：？！（）【】':
+             #     www = 0.5
+             if token not in occurrence:
+                 occurrence[token] = www
+             else:
+                 occurrence[token] += www
+             # print(occurrence) # debug
+
+             # output
+             tmp = self.decode(all_tokens[out_last:])
+             if '\ufffd' not in tmp: # is valid utf-8 string?
+                 if callback:
+                     callback(tmp)
+                 out_str += tmp
+                 out_last = i + 1
+         return out_str
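
A minimal generation sketch using the vendored pipeline above; the checkpoint path and strategy string are placeholders, not values taken from this repository's configuration.

# Minimal sketch (checkpoint path and strategy are placeholders).
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

model = RWKV(model="path/to/rwkv-checkpoint", strategy="cuda fp16")
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, top_k=0, token_stop=[0])
out = pipeline.generate("User: Hello!\n\nAssistant:", token_count=64, args=args,
                        callback=lambda s: print(s, end="", flush=True))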