// File size: 11,970 Bytes
// Revision: 25b4ce2
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdexcept>
#include <torch/extension.h>
#include "gemv_cuda.h"
#include "../dequantize.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
// Widens a half-precision value to single precision.
static inline __device__ float to_float(half src)
{
    const float widened = __half2float(src);
    return widened;
}
// Identity overload: lets callers use to_float() on a value that is already a float.
static inline __device__ float to_float(float src)
{
    const float unchanged = src;
    return unchanged;
}
// Narrows a single-precision value to half precision.
static inline __device__ half to_half(float src)
{
    const half narrowed = __float2half(src);
    return narrowed;
}
// Identity overload: lets callers use to_half() on a value that is already a half.
static inline __device__ half to_half(half src)
{
    const half unchanged = src;
    return unchanged;
}
// Partial in-warp sum of the per-thread accumulators, followed by a store of
// the per-warp results to shared memory.
//
// The XOR shuffles use lane masks 16, 8 and 1, so each lane ends up with the
// sum over the eight lanes that share bits 1 and 2 of its lane id
// (e.g. T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25). Lanes 0, 2, 4 and 6 each
// represent one of the four groups and write to column `i * 4 + lane / 2` of
// row `warp` in out_smem.
//
//   psum     - Num half accumulators owned by the calling thread.
//   out_smem - rows of Num * 4 floats, indexed by warp id within the block.
//
// Contains block-wide barriers, so every thread of the block must call it.
template <int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4])
{
    constexpr unsigned int kFullWarpMask = 0xffffffffu;

    float totals[Num];
#pragma unroll
    for (int n = 0; n < Num; ++n)
    {
        // Accumulate in float; each butterfly step folds in the partner lane's value.
        float v = to_float(psum[n]);
        v += __shfl_xor_sync(kFullWarpMask, v, 16);
        v += __shfl_xor_sync(kFullWarpMask, v, 8);
        v += __shfl_xor_sync(kFullWarpMask, v, 1);
        totals[n] = v;
    }

    __syncthreads();

    const int warp_id = threadIdx.x / WarpSize;
    const int lane_id = threadIdx.x % WarpSize;
    // Even lanes below 8, i.e. lanes 0, 2, 4 and 6.
    const bool is_writer = (lane_id < 8) && ((lane_id & 1) == 0);
    if (is_writer)
    {
        float* warp_row = out_smem[warp_id];
        const int slot = lane_id >> 1;
#pragma unroll
        for (int n = 0; n < Num; ++n)
        {
            warp_row[n * 4 + slot] = totals[n];
        }
    }

    __syncthreads();
}
// Integer division of c by divisor, with c first bumped by (divisor - 1);
// for positive operands this is the quotient rounded up.
__device__ __forceinline__ int make_divisible(int c, int divisor)
{
    const int bumped = c + divisor - 1;
    return bumped / divisor;
}
/*
 * Batched GEMV over 4-bit packed weights. Each thread block produces
 * NPerBlock * kInterleave (kInterleave = 4) consecutive output channels,
 * starting at blockIdx.x * NPerBlock * 4, for each of the `Batch` input rows.
 *
 * Template parameters:
 *   NPerBlock - number of 4-row interleave groups handled per block. The MAC
 *               loop advances in half2 pairs (x < NPerBlock / 2), so an odd
 *               value would leave the last group unaccumulated.
 *               NOTE(review): confirm only even values are instantiated.
 *   Batch     - number of input rows; row b is read from inputs + b * IC.
 *   BlockSize - thread count used for all loop strides (blockDim.x is never
 *               read); presumably must equal the launched block size.
 *   GroupSize - number of input channels that share one scale / zero entry.
 *
 * Parameters:
 *   inputs  - half activations, read as Batch rows of IC values.
 *   weight  - packed weights, 8 x 4-bit values per uint32_t, fetched with
 *             128-bit loads.
 *   scales  - half scales addressed as scales[group * OC + output_channel].
 *   zeros   - half additive terms with the same addressing; applied as
 *             w * scale + zero (presumably pre-scaled by the packer - confirm).
 *   outputs - half results written to outputs[b * OC + output_channel].
 *   IC, OC  - input / output channel counts.
 *
 * Launch requirements implied by the code below:
 *   - out_smem lives in `extern __shared__` storage, so the launch must supply
 *     at least (BlockSize / WARP_SIZE) * NPerBlock * Batch * 4 * sizeof(float)
 *     bytes of dynamic shared memory.
 *   - There are no bounds checks on output channels: gridDim.x * NPerBlock * 4
 *     must not exceed OC.
 *   - K is consumed in tiles of kStride = 64 elements with 128-bit loads and
 *     no tail handling.
 */
template <int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs,
const int IC, const int OC)
{
// K elements covered by one tile; two threads (kThreadsNumPerTile) of 32
// elements (kElemsPerThread) each cover a tile for one interleaved row.
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
// assert(MEM_ACCESS_SIZE == 128);
// Constants describing the packed-weight shuffle undone below.
// kShuffleSize is not referenced anywhere in this kernel.
static constexpr int kShuffleSize = 32;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
// Number of accumulators per thread: one per (batch row, interleave group).
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
// Per-thread working storage.
half local_inputs[kElemsPerThread];
uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
half half_weight_buffer[kElemsPerThread];
half dequantized_weight[kElemsPerThread * NPerBlock];
half local_scale[NPerBlock];
half local_scaled_zeros[NPerBlock];
half psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = to_half(0.f);
// Dynamic shared memory viewed as [warp][Num * kInterleave] floats; the
// size is not declared here and must be provided at launch.
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
// First output channel owned by this block.
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
// Which of the 4 interleaved rows this thread works on (bits 1-2 of threadIdx.x).
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
// Starting K index of this thread: every 8 threads share one 64-element
// tile, and the low bit of threadIdx.x picks the 32-element half of it.
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half* inputs_ptr = inputs + act_k_offset;
// K elements the whole block advances per main-loop iteration, and the
// matching advance (in halves) through the scale / zero arrays.
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs
// kk indexes the interleaved weight stream (IC * kInterleave elements per
// group of 4 rows), which is why the bound is IC * kInterleave.
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4*)(local_qweights)) =
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
// The half2 pair at position (i + j * 4) of the buffer is stored at
// position (i * 4 + j), i.e. a 4 x 4 transpose of half2 pairs. The
// destination is laid out [k][NPerBlock] so the MAC loop can read the
// weights of adjacent interleave groups as one half2.
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
half2 w =
*reinterpret_cast<half2*>(
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile
);
// w = w * scale + zero, applied to both halves at once.
w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
}
}
}
// The dequantized weights are reused for every batch row.
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const half* local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
// load activation, 8 halves (128 bits) / step.
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
// Two interleave groups are accumulated per half2 FMA; accumulation
// stays in half precision until warp_reduce converts to float.
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
*reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2)
= __hfma2(*reinterpret_cast<half2*>(dequantized_weight + y * NPerBlock + x * 2),
__half2half2(local_inputs[y]),
*reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2));
}
}
}
inputs_ptr += act_forward_step;
scale_ptr += scale_forward_step;
zeros_ptr += scale_forward_step;
}
// Reduce within each warp and publish one row of partial sums per warp to
// shared memory (warp_reduce ends with a block-wide barrier).
warp_reduce<Num, WARP_SIZE>(psum, out_smem);
// Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
// Sum the per-warp partial results for this output element.
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc);
}
}
// Launches gemv_kernel for one compile-time batch size.
// gemv_kernel keeps its cross-warp reduction buffer in `extern __shared__`
// storage, so the launch must state its size: one row of
// NPerBlock * Batch * Interleave floats per warp of the block.
template <int Batch, int NPerBlock, int Interleave, int BlockSize, int GroupSize>
static void launch_gemv_decode_kernel(
    const dim3 grid,
    const half* in_feats, const uint32_t* kernel, const half* scaling_factors, const half* zeros,
    half* out_feats, const int k, const int n)
{
    constexpr size_t kSmemBytes =
        sizeof(float) * (BlockSize / WARP_SIZE) * (NPerBlock * Batch * Interleave);
    gemv_kernel<NPerBlock, Batch, BlockSize, GroupSize><<<grid, dim3(BlockSize), kSmemBytes>>>(
        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
}

/*
Computes a small-batch GEMV with 4-bit packed weights (PyTorch interface).
Args:
  _in_feats: half tensor; its last dimension is replaced by n to form the
    output shape, and its data is read as m rows of k values;
  _kernel: packed weight tensor, read as 32-bit words (8 x 4-bit weights each);
  _scaling_factors: half tensor, addressed by the kernel as [group * n + output channel];
  _zeros: half tensor with the same addressing, added after scaling;
  m: batch size (number of input rows); 1 to 7 are supported;
  n: number of output channels (OC); must be a positive multiple of 8;
  k: number of input channels (IC); must be a multiple of 64;
  group_size: quantization group size; only 128 is supported.
Returns:
  out_feats: tensor with the dtype/device of _in_feats and shape [..., n].
Throws:
  std::runtime_error for an unsupported group size, batch size or shape, or
  when the kernel launch is rejected.
The kernel is launched on the default stream, as before.
*/
torch::Tensor gemv_forward_cuda_decode(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int m,
    int n,
    int k,
    int group_size)
{
    static constexpr int N_PER_BLOCK = 2;
    static constexpr int K_INTERLEAVE = 4;
    static constexpr int BLOCK_SIZE = 256;
    static constexpr int GROUP_SIZE = 128;
    // The kernel walks K in 64-element tiles with no tail handling.
    static constexpr int K_TILE = 64;
    static constexpr int OC_PER_BLOCK = N_PER_BLOCK * K_INTERLEAVE;

    // Validate before allocating or launching anything.
    // NOTE: a group size of 64 is not implemented (no kernel is instantiated for it).
    if (group_size != GROUP_SIZE)
    {
        throw std::runtime_error("Unsupported group size for gemv kernel.\n");
    }
    if (m < 1 || m > 7)
    {
        throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
    }
    // Each block writes OC_PER_BLOCK channels without bounds checks; a
    // truncated grid would leave the tail of the output uninitialized.
    if (n <= 0 || n % OC_PER_BLOCK != 0)
    {
        throw std::runtime_error("gemv kernel requires n to be a positive multiple of 8.\n");
    }
    if (k < 0 || k % K_TILE != 0)
    {
        throw std::runtime_error("gemv kernel requires k to be a multiple of 64.\n");
    }

    std::vector<int64_t> output_shape = _in_feats.sizes().vec();
    output_shape.back() = n;
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty(output_shape, options);

    const half* in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
    const uint32_t* kernel = reinterpret_cast<const uint32_t*>(_kernel.data_ptr());
    const half* zeros = reinterpret_cast<const half*>(_zeros.data_ptr<at::Half>());
    const half* scaling_factors = reinterpret_cast<const half*>(_scaling_factors.data_ptr<at::Half>());
    half* out_feats = reinterpret_cast<half*>(_out_feats.data_ptr());

    const dim3 num_blocks(n / OC_PER_BLOCK);

    // The batch size is a template parameter of the kernel, so dispatch on m.
    switch (m)
    {
    case 1:
        launch_gemv_decode_kernel<1, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    case 2:
        launch_gemv_decode_kernel<2, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    case 3:
        launch_gemv_decode_kernel<3, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    case 4:
        launch_gemv_decode_kernel<4, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    case 5:
        launch_gemv_decode_kernel<5, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    case 6:
        launch_gemv_decode_kernel<6, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    case 7:
        launch_gemv_decode_kernel<7, N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE, GROUP_SIZE>(
            num_blocks, in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
        break;
    default:
        throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
    }

    // A kernel launch does not return a status; a rejected launch configuration
    // is only visible through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess)
    {
        throw std::runtime_error(cudaGetErrorString(launch_status));
    }
    return _out_feats;
}
// (end of listing)