#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>

// Read-only load helper. A plain dereference is the portable fallback; on
// CUDA builds this could map to __ldg(arg) to route loads through the
// read-only data cache, where the scalar type supports it.
#define VLLM_LDG(arg) *(arg)

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                 \
      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
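
// For reference, a simplified sketch of what the dispatch expands to (the
// real AT_DISPATCH_* machinery also raises an error for unsupported dtypes):
//
//   switch (type) {
//     case at::ScalarType::Float:    { using scalar_t = float;        body(); break; }
//     case at::ScalarType::Half:     { using scalar_t = at::Half;     body(); break; }
//     case at::ScalarType::BFloat16: { using scalar_t = at::BFloat16; body(); break; }
//     default: TORCH_CHECK(false, "unsupported dtype");
//   }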

// SiLU (a.k.a. swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// The arithmetic is done in float regardless of T to avoid precision loss
// for half/bfloat16 inputs.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Fused SiLU-and-multiply. Each token's input row holds 2 * d elements: the
// first d are gated through SiLU and multiplied elementwise by the second d,
// writing d elements to `out`. One thread block per token; threads stride
// across the hidden dimension d.
template <typename scalar_t>
__global__ void silu_and_mul_kernel(
    scalar_t* __restrict__ out,          // [num_tokens, d]
    const scalar_t* __restrict__ input,  // [num_tokens, 2 * d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

// Host-side launcher: out = silu(input[..., :d]) * input[..., d:].
// Assumes `input` of shape [..., 2 * d] and `out` of shape [..., d], both
// contiguous and on the same CUDA device.
void silu_and_mul(
    torch::Tensor& out,    // [..., d]
    torch::Tensor& input)  // [..., 2 * d]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_kernel", [&] {
        silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
      });
}
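
// A minimal binding sketch so the op is callable from Python, assuming this
// file is compiled as a PyTorch C++/CUDA extension (TORCH_EXTENSION_NAME is
// supplied at build time by torch.utils.cpp_extension). The doc string is
// illustrative, not part of the original file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("silu_and_mul", &silu_and_mul,
        "out = silu(input[..., :d]) * input[..., d:] (CUDA)");
}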