Upload custom kernels
- build.toml +15 -0
- flake.nix +13 -0
- rmsnorm_kernel/rmsnorm.cu +163 -0
- torch-ext/rmsnorm_kernel/__init__.py +14 -0
- torch-ext/torch_binding.cpp +11 -0
- torch-ext/torch_binding.h +5 -0
build.toml
ADDED
@@ -0,0 +1,15 @@
+[general]
+name = "rmsnorm_kernel"
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+
+[kernel.rmsnorm_kernel]
+src = [
+  "rmsnorm_kernel/rmsnorm.cu",
+]
+depends = [ "torch" ]
+cuda-capabilities = [ "12.3" ]
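A note on the manifest: the [torch] table lists the C++ binding sources that kernel-builder compiles into the Torch extension, while [kernel.rmsnorm_kernel] lists the CUDA sources and declares their dependency on Torch. cuda-capabilities names the CUDA compute capabilities to generate code for, so the "12.3" here should match the target GPU; for broader coverage one might instead list, say, "8.0" (A100-class) or "9.0" (H100-class).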
flake.nix
ADDED
@@ -0,0 +1,13 @@
+{
+  description = "Flake for Torch kernel extension";
+
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
+  };
+
+  outputs = { self, kernel-builder, }:
+    kernel-builder.lib.genFlakeOutputs {
+      path = ./.;
+      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+    };
+}
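The flake pins huggingface/kernel-builder as its only input and delegates everything to genFlakeOutputs, which derives the build and dev-shell outputs for the supported Torch/CUDA combinations. With Nix installed, `nix build . -L` should compile the extension; that invocation is the standard kernel-builder workflow from its docs, not anything specific to this repo.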
rmsnorm_kernel/rmsnorm.cu
ADDED
@@ -0,0 +1,163 @@
+#include <torch/extension.h>
+#include <thrust/execution_policy.h>
+#include <thrust/for_each.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <cmath>
+#include <thrust/device_vector.h>
+#include <thrust/copy.h>
+#include <iostream>
+#include <iomanip> // For formatting output
+
+const float EPS = 1e-5f;
+
+// CPU implementation of RMSNorm
+torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
+    int B = x.size(0), S = x.size(1), H = x.size(2);
+    auto out = torch::empty_like(x);
+
+    auto x_accessor = x.accessor<float, 3>();
+    auto gamma_accessor = gamma.accessor<float, 1>();
+    auto out_accessor = out.accessor<float, 3>();
+
+    // Process each row
+    for (int b = 0; b < B; ++b) {
+        for (int s = 0; s < S; ++s) {
+            // Calculate root mean square
+            float sum_sq = 0.0f;
+            for (int h = 0; h < H; ++h) {
+                float val = x_accessor[b][s][h];
+                sum_sq += val * val;
+            }
+            float rms = std::sqrt(sum_sq / H + EPS);
+
+            // Normalize and scale
+            for (int h = 0; h < H; ++h) {
+                out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
+            }
+        }
+    }
+
+    return out;
+}
+
+struct RmsnormFunctor {
+    const float* x;
+    const float* gamma;
+    float* out;
+    int hidden_dim;
+
+    RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
+        : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}
+
+    __device__
+    void operator()(int row_idx) {
+        const float* row_x = x + row_idx * hidden_dim;
+        float* row_out = out + row_idx * hidden_dim;
+
+        float sum_sq = 0.0f;
+        for (int i = 0; i < hidden_dim; ++i)
+            sum_sq += row_x[i] * row_x[i];
+
+        float rms = sqrtf(sum_sq / hidden_dim + EPS);
+
+        for (int i = 0; i < hidden_dim; ++i)
+            row_out[i] = (row_x[i] / rms) * gamma[i];
+    }
+};
+
+torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma) {
+    int B = x.size(0), S = x.size(1), H = x.size(2);
+    int rows = B * S;
+
+    // Create output tensor with same shape as input
+    auto out = torch::empty_like(x);
+
+    const float* x_ptr = x.data_ptr<float>();
+    const float* gamma_ptr = gamma.data_ptr<float>();
+    float* out_ptr = out.data_ptr<float>();
+
+    thrust::counting_iterator<int> iter(0);
+    thrust::for_each(
+        thrust::device,
+        iter, iter + rows,
+        RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
+    );
+
+    return out;
+}
+
+// int main() {
+//     int B = 2, S = 2, H = 4;
+
+//     // Create tensors directly on CPU first
+//     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
+
+//     // Initialize with CPU data
+//     torch::Tensor x_cpu = torch::tensor({
+//         {
+//             {1.0f, 2.0f, 3.0f, 4.0f},
+//             {5.0f, 6.0f, 7.0f, 8.0f}
+//         },
+//         {
+//             {2.0f, 2.0f, 2.0f, 2.0f},
+//             {9.0f, 10.0f, 11.0f, 12.0f}
+//         }
+//     }, options_cpu);
+
+//     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
+
+//     // Run CPU version
+//     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
+//     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
+//     auto cpu_accessor = out_cpu.accessor<float, 3>();
+
+//     for (int b = 0; b < B; ++b) {
+//         for (int s = 0; s < S; ++s) {
+//             std::cout << "Row " << (b * S + s) << ": ";
+//             for (int h = 0; h < H; ++h) {
+//                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
+//             }
+//             std::cout << "\n";
+//         }
+//     }
+
+//     // Move tensors to CUDA for GPU version
+//     auto cuda_options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+//     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
+//     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
+
+//     // Call the CUDA kernel wrapper
+//     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
+//     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
+
+//     // Copy result back to CPU and print
+//     auto gpu_result_on_cpu = out_cuda.cpu();
+//     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
+
+//     for (int b = 0; b < B; ++b) {
+//         for (int s = 0; s < S; ++s) {
+//             std::cout << "Row " << (b * S + s) << ": ";
+//             for (int h = 0; h < H; ++h) {
+//                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
+//             }
+//             std::cout << "\n";
+//         }
+//     }
+
+//     // Check if results match
+//     std::cout << "\n===== COMPARISON =====" << std::endl;
+//     float max_diff = 0.0f;
+//     for (int b = 0; b < B; ++b) {
+//         for (int s = 0; s < S; ++s) {
+//             for (int h = 0; h < H; ++h) {
+//                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
+//                 max_diff = std::max(max_diff, diff);
+//             }
+//         }
+//     }
+//     std::cout << "Maximum difference between CPU and GPU results: "
+//               << std::scientific << max_diff << std::endl;
+//     std::cout << (max_diff < 1e-5 ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
+
+//     return 0;
+// }
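The functor assigns one (batch, sequence) row per Thrust iteration: it accumulates the sum of squares over the hidden dimension, forms rms = sqrt(mean + EPS), and scales each element by gamma. A few lines of plain PyTorch are enough to sanity-check the kernel's output; this is a minimal sketch, and the commented comparison assumes the built op is in scope as ops.rmsnorm_forward, as in the Python wrapper below:

    import torch

    def rmsnorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Normalize by the root mean square over the hidden (last) dimension,
        # then scale elementwise by gamma; eps matches the kernel's EPS.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x / rms) * gamma

    x = torch.randn(2, 2, 4)
    gamma = torch.ones(4)
    print(rmsnorm_ref(x, gamma))
    # On a CUDA machine, compare against the custom op:
    # torch.testing.assert_close(
    #     ops.rmsnorm_forward(x.cuda(), gamma.cuda()).cpu(), rmsnorm_ref(x, gamma))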
torch-ext/rmsnorm_kernel/__init__.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+
+from ._ops import ops
+
+
+class LlamaRMSNorm(nn.Module):
+    weight: torch.Tensor
+    variance_epsilon: float
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # The registered op takes only the input and the scale vector; the
+        # kernel hard-codes EPS, so variance_epsilon is not passed through.
+        return ops.rmsnorm_forward(hidden_states, self.weight)
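A minimal usage sketch, assuming the built extension is importable as rmsnorm_kernel; since the class defines no __init__, the weight and epsilon are attached by hand here:

    import torch
    from rmsnorm_kernel import LlamaRMSNorm  # assumed import path once built

    norm = LlamaRMSNorm()
    norm.weight = torch.nn.Parameter(torch.ones(4, device="cuda"))
    norm.variance_epsilon = 1e-5  # kept for interface parity; the kernel uses its own EPS
    y = norm(torch.randn(2, 3, 4, device="cuda", dtype=torch.float32))
    print(y.shape)  # torch.Size([2, 3, 4])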
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,11 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("rmsnorm_forward(Tensor input, Tensor gamma) -> Tensor");
+  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
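The schema now declares "-> Tensor" because rmsnorm_forward returns the output tensor. TORCH_LIBRARY_EXPAND registers the op under the build-specific extension namespace, and the _ops shim that kernel-builder generates re-exports it; that is why the Python layer calls ops.rmsnorm_forward rather than spelling out torch.ops.<namespace>.rmsnorm_forward, whose exact namespace is an implementation detail of the build.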
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <torch/torch.h>
+
+torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma);
|