#include <assert.h>
#include <cuda_fp16.h>
#include "ATen/ATen.h"

// Tuning constants for the single-row matmul kernels below. These values
// are assumptions matching the upstream RWKV kernels; adjust if your build
// defines them differently.
#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024

// at::Half and CUDA's __half share the same 16-bit layout, so a
// reinterpret_cast between the two pointer types is safe.
typedef at::Half fp16;
__half *cast(fp16 *ptr) {
    return reinterpret_cast<__half *>(ptr);
}
template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
                                   const float *__restrict__ const _w, const float *__restrict__ const _u,
                                   const F *__restrict__ const _k, const F *__restrict__ const _v,
                                   F *__restrict__ const _y,
                                   float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
    // One thread per (batch, channel) pair; each thread scans all T steps.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;
    const int _state_offset = _b * C + _c;

    float u = _u[_c];
    float w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // Running WKV state: numerator (aa), denominator (bb), and the current
    // exponent offset (pp) used to keep the exp() terms in range.
    float aa = _aa[_state_offset];
    float bb = _bb[_state_offset];
    float pp = _pp[_state_offset];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        // Output at step i: mix the state with the current token (which
        // gets the bonus u), computed in log-space to avoid overflow.
        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));

        // Decay the state by w and fold in the current token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    // Persist the state so the next chunk can resume where this one left off.
    _aa[_state_offset] = aa;
    _bb[_state_offset] = bb;
    _pp[_state_offset] = pp;
}
template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
    // Launch one thread per (batch, channel) pair; B * C must divide
    // evenly into blocks of min(C, 32) threads.
    dim3 threadsPerBlock( min(C, 32) );
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}

template void cuda_wkv_forward<fp16>(
    int B, int T, int C,
    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
    float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
    int B, int T, int C,
    float *w, float *u, float *k, float *v, float *y,
    float *aa, float *bb, float *pp);
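
// Usage sketch (not part of the original file): a minimal host-side driver
// for the float path, assuming device buffers w, u, k, v, y were already
// allocated and filled by the caller. Shapes follow the kernel contract:
// k/v/y are (B, T, C), w/u are (C,), and the aa/bb/pp state is (B, C).
// The state starts with aa = bb = 0 and pp very negative (upstream uses
// -1e38) so that exp(pp - p) underflows to zero on the first token.
void example_wkv_forward(int B, int T, int C,
                         float *w, float *u, float *k, float *v, float *y) {
    float *aa, *bb, *pp;
    const size_t state_bytes = (size_t)B * C * sizeof(float);
    cudaMalloc(&aa, state_bytes);
    cudaMalloc(&bb, state_bytes);
    cudaMalloc(&pp, state_bytes);
    cudaMemset(aa, 0, state_bytes);  // zero numerator state
    cudaMemset(bb, 0, state_bytes);  // zero denominator state

    // cudaMemset writes bytes, so fill pp on the host and copy it over.
    float *host_pp = new float[(size_t)B * C];
    for (size_t i = 0; i < (size_t)B * C; ++i) host_pp[i] = -1e38f;
    cudaMemcpy(pp, host_pp, state_bytes, cudaMemcpyHostToDevice);
    delete[] host_pp;

    cuda_wkv_forward<float>(B, T, C, w, u, k, v, y, aa, bb, pp);
    cudaDeviceSynchronize();

    cudaFree(aa); cudaFree(bb); cudaFree(pp);
}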
// Sequence-mode matmul with int8-quantized weights:
//   y[i,k] = sum_j x[i,j] * W[j,k]
// where each uint8 weight is dequantized on the fly as
//   W[j,k] = (w[j,k] + 0.5) * rx[k] * ry[j] + mx[k] + my[j]
// using per-column (rx, mx) and per-row (ry, my) scales and offsets.
__global__ void kernel_mm_seq_fp32i8(
    const int B, const int N, const int M,
    const float *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y, const int y_stride) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;  // batch row
    const int k = blockIdx.y * blockDim.y + threadIdx.y;  // output column
    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += x[i * x_stride + j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        y[i * y_stride + k] = y_local;
    }
}
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

template <>
void cuda_mm8_seq<float>(int B, int N, int M,
                         float *x, int x_stride,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
        B, N, M, x, x_stride, w, w_stride,
        mx, rx, my, ry, y, y_stride);
}
// Same dequantized matmul with fp16 activations; accumulation stays in fp32.
__global__ void kernel_mm_seq_fp16i8(
    const int B, const int N, const int M,
    const __half *__restrict__ const x, const int x_stride,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    __half *__restrict__ const y, const int y_stride) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    if (i < B && k < M) {
        float y_local = 0;
        for (int j = 0; j < N; ++j) {
            y_local += __half2float(x[i * x_stride + j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        y[i * y_stride + k] = __float2half(y_local);
    }
}
template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
                        fp16 *x, int x_stride,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        fp16 *y, int y_stride) {
    dim3 blockSize(1, 128);
    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
        B, N, M, cast(x), x_stride, w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}
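
// Reference sketch (not part of the original file): a plain CPU version of
// the same dequantized matmul, handy for validating the kernels above.
// Assumes row-major, densely packed inputs (x_stride = N, w_stride = M,
// y_stride = M).
void mm8_seq_reference(int B, int N, int M,
                       const float *x, const uint8_t *w,
                       const float *mx, const float *rx,
                       const float *my, const float *ry,
                       float *y) {
    for (int i = 0; i < B; ++i) {
        for (int k = 0; k < M; ++k) {
            float acc = 0;
            for (int j = 0; j < N; ++j) {
                // Dequantize: (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j]
                acc += x[i * N + j] * ((float(w[j * M + k]) + 0.5f)
                                       * rx[k] * ry[j] + mx[k] + my[j]);
            }
            y[i * M + k] = acc;
        }
    }
}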
// Single-row variant (x is one vector of length N). The j dimension is
// split across MM8_ONE_JSPLIT blocks, each accumulating a partial dot
// product into y[k] with atomicAdd, so y must be zeroed before launch.
__global__ void kernel_mm_one_fp32i8(
    const int N, const int M,
    const float *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const float *__restrict__ const mx,
    const float *__restrict__ const rx,
    const float *__restrict__ const my,
    const float *__restrict__ const ry,
    float *__restrict__ const y) {
    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += x[j] * (
                (float(w[j * w_stride + k]) + 0.5f)
                * rx[k] * ry[j] + mx[k] + my[j]
            );
        }
        atomicAdd(&y[k], y_local);
    }
}
template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

template <>
void cuda_mm8_one<float>(int N, int M,
                         float *x,
                         uint8_t *w, int w_stride,
                         float *mx, float *rx,
                         float *my, float *ry,
                         float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
        N, M, x, w, w_stride,
        mx, rx, my, ry, y);
}
// fp16 single-row variant; partial sums accumulate into an fp32 output.
__global__ void kernel_mm_one_fp16i8(
    const int N, const int M,
    const __half *__restrict__ const x,
    const uint8_t *__restrict__ const w, const int w_stride,
    const __half *__restrict__ const mx,
    const __half *__restrict__ const rx,
    const __half *__restrict__ const my,
    const __half *__restrict__ const ry,
    float *__restrict__ const y) {
    const int k = blockIdx.y * blockDim.y + threadIdx.y;
    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
    if (k < M) {
        float y_local = 0;
        for (int j = j0; j < j1; ++j) {
            y_local += __half2float(x[j]) * (
                (float(w[j * w_stride + k]) + 0.5f)
                * __half2float(rx[k]) * __half2float(ry[j])
                + __half2float(mx[k]) + __half2float(my[j])
            );
        }
        atomicAdd(&y[k], y_local);
    }
}
template <>
void cuda_mm8_one<fp16>(int N, int M,
                        fp16 *x,
                        uint8_t *w, int w_stride,
                        fp16 *mx, fp16 *rx,
                        fp16 *my, fp16 *ry,
                        float *y) {
    dim3 blockSize(1, MM8_ONE_TILE);
    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
        N, M, cast(x), w, w_stride,
        cast(mx), cast(rx), cast(my), cast(ry), y);
}
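
// Usage sketch (not part of the original file): driving the single-row
// float path, assuming device buffers were allocated and filled by the
// caller and that w is densely packed (w_stride = M). Because the kernel
// accumulates with atomicAdd across the j-splits, the fp32 output must be
// zeroed before every launch.
void example_mm8_one(int N, int M,
                     float *x, uint8_t *w,
                     float *mx, float *rx, float *my, float *ry,
                     float *y) {
    cudaMemset(y, 0, M * sizeof(float));  // required: kernel does y[k] += ...
    cuda_mm8_one<float>(N, M, x, w, /*w_stride=*/M, mx, rx, my, ry, y);
    cudaDeviceSynchronize();
}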