Sync with upstream
Browse files- activation/activation_kernels.cu +28 -7
- activation/cuda_compat.h +3 -3
- activation/dispatch_utils.h +48 -0
- build/torch26-cxx98-cu124-x86_64-linux/activation/layers.py +47 -0
- tests/kernels/test_activation.py +17 -2
- torch-ext/activation/__init__.py +5 -0
- torch-ext/activation/layers.py +49 -0
- torch-ext/torch_binding.cpp +3 -0
- torch-ext/torch_binding.h +2 -0
activation/activation_kernels.cu
CHANGED
|
@@ -9,8 +9,16 @@
|
|
| 9 |
|
| 10 |
namespace vllm {
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
// Activation and gating kernel template.
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
__global__ void act_and_mul_kernel(
|
| 15 |
scalar_t* __restrict__ out, // [..., d]
|
| 16 |
const scalar_t* __restrict__ input, // [..., 2, d]
|
|
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
|
|
| 19 |
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
| 20 |
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
| 21 |
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
| 22 |
-
out[token_idx * d + idx] = ACT_FN(x
|
| 23 |
}
|
| 24 |
}
|
| 25 |
|
|
@@ -55,16 +63,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
|
| 55 |
} // namespace vllm
|
| 56 |
|
| 57 |
// Launch activation and gating kernel.
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
int d = input.size(-1) / 2; \
|
| 60 |
int64_t num_tokens = input.numel() / input.size(-1); \
|
| 61 |
dim3 grid(num_tokens); \
|
| 62 |
dim3 block(std::min(d, 1024)); \
|
|
|
|
|
|
|
|
|
|
| 63 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
| 64 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
| 65 |
VLLM_DISPATCH_FLOATING_TYPES( \
|
| 66 |
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
| 67 |
-
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t
|
| 68 |
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
| 69 |
input.data_ptr<scalar_t>(), d); \
|
| 70 |
});
|
|
@@ -72,19 +85,27 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
|
| 72 |
void silu_and_mul(torch::Tensor& out, // [..., d]
|
| 73 |
torch::Tensor& input) // [..., 2 * d]
|
| 74 |
{
|
| 75 |
-
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
}
|
| 77 |
|
| 78 |
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
| 79 |
torch::Tensor& input) // [..., 2 * d]
|
| 80 |
{
|
| 81 |
-
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
| 82 |
}
|
| 83 |
|
| 84 |
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
| 85 |
torch::Tensor& input) // [..., 2 * d]
|
| 86 |
{
|
| 87 |
-
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
| 88 |
}
|
| 89 |
|
| 90 |
namespace vllm {
|
|
|
|
| 9 |
|
| 10 |
namespace vllm {
|
| 11 |
|
| 12 |
+
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
| 13 |
+
bool act_first>
|
| 14 |
+
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
| 15 |
+
const scalar_t& y) {
|
| 16 |
+
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
|
| 17 |
+
}
|
| 18 |
// Activation and gating kernel template.
|
| 19 |
+
|
| 20 |
+
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
| 21 |
+
bool act_first>
|
| 22 |
__global__ void act_and_mul_kernel(
|
| 23 |
scalar_t* __restrict__ out, // [..., d]
|
| 24 |
const scalar_t* __restrict__ input, // [..., 2, d]
|
|
|
|
| 27 |
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
| 28 |
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
| 29 |
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
| 30 |
+
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
| 31 |
}
|
| 32 |
}
|
| 33 |
|
|
|
|
| 63 |
} // namespace vllm
|
| 64 |
|
| 65 |
// Launch activation and gating kernel.
|
| 66 |
+
// Use ACT_FIRST (bool) indicating whether to apply the activation function
|
| 67 |
+
// first.
|
| 68 |
+
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
|
| 69 |
int d = input.size(-1) / 2; \
|
| 70 |
int64_t num_tokens = input.numel() / input.size(-1); \
|
| 71 |
dim3 grid(num_tokens); \
|
| 72 |
dim3 block(std::min(d, 1024)); \
|
| 73 |
+
if (num_tokens == 0) { \
|
| 74 |
+
return; \
|
| 75 |
+
} \
|
| 76 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
| 77 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
| 78 |
VLLM_DISPATCH_FLOATING_TYPES( \
|
| 79 |
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
| 80 |
+
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
|
| 81 |
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
| 82 |
input.data_ptr<scalar_t>(), d); \
|
| 83 |
});
|
|
|
|
| 85 |
void silu_and_mul(torch::Tensor& out, // [..., d]
|
| 86 |
torch::Tensor& input) // [..., 2 * d]
|
| 87 |
{
|
| 88 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
void mul_and_silu(torch::Tensor& out, // [..., d]
|
| 92 |
+
torch::Tensor& input) // [..., 2 * d]
|
| 93 |
+
{
|
| 94 |
+
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
|
| 95 |
+
// applies the silu to the latter half of the input.
|
| 96 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
|
| 97 |
}
|
| 98 |
|
| 99 |
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
| 100 |
torch::Tensor& input) // [..., 2 * d]
|
| 101 |
{
|
| 102 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
|
| 103 |
}
|
| 104 |
|
| 105 |
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
| 106 |
torch::Tensor& input) // [..., 2 * d]
|
| 107 |
{
|
| 108 |
+
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
|
| 109 |
}
|
| 110 |
|
| 111 |
namespace vllm {
|
activation/cuda_compat.h
CHANGED
|
@@ -4,10 +4,10 @@
|
|
| 4 |
#include <hip/hip_runtime.h>
|
| 5 |
#endif
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
#define WARP_SIZE
|
| 9 |
#else
|
| 10 |
-
#define WARP_SIZE
|
| 11 |
#endif
|
| 12 |
|
| 13 |
#ifndef USE_ROCM
|
|
|
|
| 4 |
#include <hip/hip_runtime.h>
|
| 5 |
#endif
|
| 6 |
|
| 7 |
+
#if defined(USE_ROCM) && defined(__GFX9__)
|
| 8 |
+
#define WARP_SIZE 64
|
| 9 |
#else
|
| 10 |
+
#define WARP_SIZE 32
|
| 11 |
#endif
|
| 12 |
|
| 13 |
#ifndef USE_ROCM
|
activation/dispatch_utils.h
CHANGED
|
@@ -6,6 +6,11 @@
|
|
| 6 |
|
| 7 |
#include <torch/all.h>
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
| 10 |
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 11 |
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
|
@@ -14,6 +19,35 @@
|
|
| 14 |
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 15 |
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
| 18 |
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 19 |
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
|
@@ -31,5 +65,19 @@
|
|
| 31 |
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 32 |
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
| 35 |
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
#include <torch/all.h>
|
| 8 |
|
| 9 |
+
// Need a special dispatch case macro since we will nest the FP8 dispatch.
|
| 10 |
+
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
|
| 11 |
+
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
|
| 12 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
|
| 13 |
+
|
| 14 |
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
| 15 |
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 16 |
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
|
|
|
| 19 |
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 20 |
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
| 21 |
|
| 22 |
+
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
|
| 23 |
+
// A host-based check at runtime will create a preferred FP8 type for ROCm
|
| 24 |
+
// such that the correct kernel is dispatched.
|
| 25 |
+
#ifdef USE_ROCM
|
| 26 |
+
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
|
| 27 |
+
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
| 28 |
+
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
|
| 29 |
+
|
| 30 |
+
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
| 31 |
+
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
| 32 |
+
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
|
| 33 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
| 34 |
+
#else
|
| 35 |
+
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
|
| 36 |
+
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
|
| 37 |
+
|
| 38 |
+
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
| 39 |
+
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
| 40 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
| 41 |
+
#endif
|
| 42 |
+
|
| 43 |
+
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
|
| 44 |
+
// See AT_DISPATCH_FP8_CASE above.
|
| 45 |
+
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
|
| 46 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
| 47 |
+
|
| 48 |
+
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
|
| 49 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
|
| 50 |
+
|
| 51 |
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
| 52 |
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 53 |
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
|
|
|
| 65 |
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 66 |
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
| 67 |
|
| 68 |
+
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
|
| 69 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
| 70 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
| 71 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
| 72 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 73 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
| 74 |
+
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
|
| 75 |
+
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
|
| 76 |
+
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
|
| 77 |
+
|
| 78 |
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
| 79 |
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
| 80 |
+
|
| 81 |
+
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
|
| 82 |
+
AT_DISPATCH_SWITCH( \
|
| 83 |
+
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
|
build/torch26-cxx98-cu124-x86_64-linux/activation/layers.py
CHANGED
|
@@ -5,6 +5,15 @@ from ._ops import ops
|
|
| 5 |
|
| 6 |
|
| 7 |
class SiluAndMul(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
can_torch_compile: bool = True
|
| 9 |
|
| 10 |
def forward(self, x: torch.Tensor):
|
|
@@ -14,8 +23,35 @@ class SiluAndMul(nn.Module):
|
|
| 14 |
ops.silu_and_mul(out, x)
|
| 15 |
return out
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
class GeluAndMul(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
can_torch_compile: bool = True
|
| 20 |
|
| 21 |
def forward(self, x: torch.Tensor):
|
|
@@ -38,6 +74,17 @@ class GeluTanhAndMul(nn.Module):
|
|
| 38 |
|
| 39 |
|
| 40 |
class FatreluAndMul(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
can_torch_compile: bool = True
|
| 42 |
|
| 43 |
def __init__(self, threshold: float = 0.0):
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class SiluAndMul(nn.Module):
|
| 8 |
+
"""An activation function for SwiGLU.
|
| 9 |
+
|
| 10 |
+
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
| 11 |
+
|
| 12 |
+
Shapes:
|
| 13 |
+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
| 14 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
can_torch_compile: bool = True
|
| 18 |
|
| 19 |
def forward(self, x: torch.Tensor):
|
|
|
|
| 23 |
ops.silu_and_mul(out, x)
|
| 24 |
return out
|
| 25 |
|
| 26 |
+
class MulAndSilu(CustomOp):
|
| 27 |
+
"""An activation function for SwiGLU.
|
| 28 |
+
|
| 29 |
+
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
|
| 30 |
+
|
| 31 |
+
Shapes:
|
| 32 |
+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
| 33 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
can_torch_compile: bool = True
|
| 37 |
+
|
| 38 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
d = x.shape[-1] // 2
|
| 40 |
+
output_shape = (x.shape[:-1] + (d, ))
|
| 41 |
+
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
| 42 |
+
self.mul_and_silu(out, x)
|
| 43 |
+
return out
|
| 44 |
|
| 45 |
class GeluAndMul(nn.Module):
|
| 46 |
+
"""An activation function for GeGLU.
|
| 47 |
+
|
| 48 |
+
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
| 49 |
+
|
| 50 |
+
Shapes:
|
| 51 |
+
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
| 52 |
+
return: (batch_size, seq_len, d) or (num_tokens, d)
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
can_torch_compile: bool = True
|
| 56 |
|
| 57 |
def forward(self, x: torch.Tensor):
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
class FatreluAndMul(nn.Module):
|
| 77 |
+
"""An activation function for FATReLU.
|
| 78 |
+
|
| 79 |
+
The function computes x -> FATReLU(x[:d]) * x[d:] where
|
| 80 |
+
d = x.shape[-1] // 2.
|
| 81 |
+
This is used in openbmb/MiniCPM-S-1B-sft.
|
| 82 |
+
|
| 83 |
+
Shapes:
|
| 84 |
+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
| 85 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
can_torch_compile: bool = True
|
| 89 |
|
| 90 |
def __init__(self, threshold: float = 0.0):
|
tests/kernels/test_activation.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
import random
|
| 3 |
from typing import Type
|
|
@@ -43,12 +46,19 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
|
| 43 |
return F.silu(x[..., :d]) * x[..., d:]
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
|
| 47 |
d = x.shape[-1] // 2
|
| 48 |
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
|
| 49 |
|
| 50 |
|
| 51 |
-
@pytest.mark.parametrize(
|
|
|
|
|
|
|
| 52 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 53 |
@pytest.mark.parametrize("d", D)
|
| 54 |
@pytest.mark.parametrize("dtype", DTYPES)
|
|
@@ -67,11 +77,16 @@ def test_act_and_mul(
|
|
| 67 |
torch.manual_seed(seed)
|
| 68 |
torch.set_default_device(device)
|
| 69 |
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
| 70 |
-
if activation_name == "
|
| 71 |
torch_fn = silu_and_mul
|
| 72 |
fn = activation.silu_and_mul
|
| 73 |
op = activation.ops.silu_and_mul
|
| 74 |
layer = activation.layers.SiluAndMul()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
elif activation_name == "gelu":
|
| 76 |
torch_fn = lambda x: gelu_and_mul(x, "none")
|
| 77 |
fn = activation.gelu_and_mul
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
|
| 4 |
import math
|
| 5 |
import random
|
| 6 |
from typing import Type
|
|
|
|
| 46 |
return F.silu(x[..., :d]) * x[..., d:]
|
| 47 |
|
| 48 |
|
| 49 |
+
def mul_and_silu(x: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
d = x.shape[-1] // 2
|
| 51 |
+
return x[..., :d] * F.silu(x[..., d:])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
|
| 55 |
d = x.shape[-1] // 2
|
| 56 |
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
|
| 57 |
|
| 58 |
|
| 59 |
+
@pytest.mark.parametrize(
|
| 60 |
+
"activation_name", ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]
|
| 61 |
+
)
|
| 62 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 63 |
@pytest.mark.parametrize("d", D)
|
| 64 |
@pytest.mark.parametrize("dtype", DTYPES)
|
|
|
|
| 77 |
torch.manual_seed(seed)
|
| 78 |
torch.set_default_device(device)
|
| 79 |
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
| 80 |
+
if activation_name == "silu_and_mul":
|
| 81 |
torch_fn = silu_and_mul
|
| 82 |
fn = activation.silu_and_mul
|
| 83 |
op = activation.ops.silu_and_mul
|
| 84 |
layer = activation.layers.SiluAndMul()
|
| 85 |
+
elif activation_name == "mul_and_silu":
|
| 86 |
+
torch_fn = mul_and_silu
|
| 87 |
+
fn = activation.mul_and_silu
|
| 88 |
+
op = activation.ops.mul_and_silu
|
| 89 |
+
layer = activation.layers.MulAndSilu()
|
| 90 |
elif activation_name == "gelu":
|
| 91 |
torch_fn = lambda x: gelu_and_mul(x, "none")
|
| 92 |
fn = activation.gelu_and_mul
|
torch-ext/activation/__init__.py
CHANGED
|
@@ -10,6 +10,11 @@ def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
|
| 10 |
return out
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 14 |
ops.gelu_and_mul(out, x)
|
| 15 |
return out
|
|
|
|
| 10 |
return out
|
| 11 |
|
| 12 |
|
| 13 |
+
def mul_and_silu(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 14 |
+
ops.mul_and_silu(out, x)
|
| 15 |
+
return out
|
| 16 |
+
|
| 17 |
+
|
| 18 |
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 19 |
ops.gelu_and_mul(out, x)
|
| 20 |
return out
|
torch-ext/activation/layers.py
CHANGED
|
@@ -5,6 +5,15 @@ from ._ops import ops
|
|
| 5 |
|
| 6 |
|
| 7 |
class SiluAndMul(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
can_torch_compile: bool = True
|
| 9 |
|
| 10 |
def forward(self, x: torch.Tensor):
|
|
@@ -15,7 +24,36 @@ class SiluAndMul(nn.Module):
|
|
| 15 |
return out
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class GeluAndMul(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
can_torch_compile: bool = True
|
| 20 |
|
| 21 |
def forward(self, x: torch.Tensor):
|
|
@@ -38,6 +76,17 @@ class GeluTanhAndMul(nn.Module):
|
|
| 38 |
|
| 39 |
|
| 40 |
class FatreluAndMul(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
can_torch_compile: bool = True
|
| 42 |
|
| 43 |
def __init__(self, threshold: float = 0.0):
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class SiluAndMul(nn.Module):
|
| 8 |
+
"""An activation function for SwiGLU.
|
| 9 |
+
|
| 10 |
+
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
| 11 |
+
|
| 12 |
+
Shapes:
|
| 13 |
+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
| 14 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
can_torch_compile: bool = True
|
| 18 |
|
| 19 |
def forward(self, x: torch.Tensor):
|
|
|
|
| 24 |
return out
|
| 25 |
|
| 26 |
|
| 27 |
+
class MulAndSilu(nn.Module):
|
| 28 |
+
"""An activation function for SwiGLU.
|
| 29 |
+
|
| 30 |
+
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
|
| 31 |
+
|
| 32 |
+
Shapes:
|
| 33 |
+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
| 34 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
can_torch_compile: bool = True
|
| 38 |
+
|
| 39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 40 |
+
d = x.shape[-1] // 2
|
| 41 |
+
output_shape = x.shape[:-1] + (d,)
|
| 42 |
+
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
| 43 |
+
ops.mul_and_silu(out, x)
|
| 44 |
+
return out
|
| 45 |
+
|
| 46 |
+
|
| 47 |
class GeluAndMul(nn.Module):
|
| 48 |
+
"""An activation function for GeGLU.
|
| 49 |
+
|
| 50 |
+
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
| 51 |
+
|
| 52 |
+
Shapes:
|
| 53 |
+
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
| 54 |
+
return: (batch_size, seq_len, d) or (num_tokens, d)
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
can_torch_compile: bool = True
|
| 58 |
|
| 59 |
def forward(self, x: torch.Tensor):
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
class FatreluAndMul(nn.Module):
|
| 79 |
+
"""An activation function for FATReLU.
|
| 80 |
+
|
| 81 |
+
The function computes x -> FATReLU(x[:d]) * x[d:] where
|
| 82 |
+
d = x.shape[-1] // 2.
|
| 83 |
+
This is used in openbmb/MiniCPM-S-1B-sft.
|
| 84 |
+
|
| 85 |
+
Shapes:
|
| 86 |
+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
| 87 |
+
return: (num_tokens, d) or (batch_size, seq_len, d)
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
can_torch_compile: bool = True
|
| 91 |
|
| 92 |
def __init__(self, threshold: float = 0.0):
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -9,6 +9,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 9 |
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
| 10 |
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
// Activation function used in GeGLU with `none` approximation.
|
| 13 |
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
| 14 |
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
|
|
|
| 9 |
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
| 10 |
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
| 11 |
|
| 12 |
+
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
| 13 |
+
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
| 14 |
+
|
| 15 |
// Activation function used in GeGLU with `none` approximation.
|
| 16 |
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
| 17 |
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
torch-ext/torch_binding.h
CHANGED
|
@@ -4,6 +4,8 @@
|
|
| 4 |
|
| 5 |
void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
|
| 6 |
|
|
|
|
|
|
|
| 7 |
void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
|
| 8 |
|
| 9 |
void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
|
|
|
|
| 4 |
|
| 5 |
void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
|
| 6 |
|
| 7 |
+
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
| 8 |
+
|
| 9 |
void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
|
| 10 |
|
| 11 |
void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);
|