#include <torch/torch.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/execution_policy.h>
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iomanip>

// constexpr (rather than plain const) so the constant is also usable from
// __device__ code below.
constexpr float EPS = 1e-5f;
// CPU reference implementation of RMSNorm:
//   out[b][s][h] = x[b][s][h] / sqrt(mean_h(x[b][s][:]^2) + EPS) * gamma[h]
torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
    int B = x.size(0), S = x.size(1), H = x.size(2);
    auto out = torch::empty_like(x);
    auto x_accessor = x.accessor<float, 3>();
    auto gamma_accessor = gamma.accessor<float, 1>();
    auto out_accessor = out.accessor<float, 3>();
    // Process each (batch, seq) row independently
    for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
            // Root mean square over the hidden dimension
            float sum_sq = 0.0f;
            for (int h = 0; h < H; ++h) {
                float val = x_accessor[b][s][h];
                sum_sq += val * val;
            }
            float rms = std::sqrt(sum_sq / H + EPS);
            // Normalize and apply the learned scale
            for (int h = 0; h < H; ++h) {
                out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
            }
        }
    }
    return out;
}
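// For reference, the same computation expressed with built-in ATen ops; a
// handy independent check against the hand-rolled loops above (a sketch,
// assuming float32 inputs of shape [B, S, H]):
//   auto rms = (x.pow(2).mean({-1}, /*keepdim=*/true) + EPS).sqrt();
//   auto out = x / rms * gamma;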
// One thread per (batch, seq) row: the functor receives a flat row index and
// normalizes the hidden_dim contiguous floats of that row.
struct RmsnormFunctor {
    const float* x;
    const float* gamma;
    float* out;
    int hidden_dim;
    RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
        : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}
    __device__ void operator()(int row_idx) {
        const float* row_x = x + row_idx * hidden_dim;
        float* row_out = out + row_idx * hidden_dim;
        float sum_sq = 0.0f;
        for (int i = 0; i < hidden_dim; ++i)
            sum_sq += row_x[i] * row_x[i];
        float rms = sqrtf(sum_sq / hidden_dim + EPS);
        for (int i = 0; i < hidden_dim; ++i)
            row_out[i] = (row_x[i] / rms) * gamma[i];
    }
};
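// A block-per-row alternative (a sketch, not part of the original file and
// not called by the wrapper below). The thrust functor above runs one thread
// per row, so each thread walks hidden_dim elements serially; for large
// hidden dims, letting a whole block cooperate on one row uses memory
// bandwidth far better. Launch as e.g.
//   rmsnorm_block_per_row<<<rows, 256, 256 * sizeof(float)>>>(x_ptr, gamma_ptr, out_ptr, H);
// with a power-of-two block size, which the tree reduction assumes.
__global__ void rmsnorm_block_per_row(const float* x, const float* gamma,
                                      float* out, int hidden_dim) {
    extern __shared__ float partial[];
    const int tid = threadIdx.x;
    const float* row_x = x + (long long)blockIdx.x * hidden_dim;
    float* row_out = out + (long long)blockIdx.x * hidden_dim;
    // Each thread accumulates the squares of a strided slice of the row.
    float local = 0.0f;
    for (int i = tid; i < hidden_dim; i += blockDim.x)
        local += row_x[i] * row_x[i];
    partial[tid] = local;
    __syncthreads();
    // Tree reduction of the partial sums in shared memory.
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride)
            partial[tid] += partial[tid + stride];
        __syncthreads();
    }
    float rms = sqrtf(partial[0] / hidden_dim + EPS);
    // Normalize and scale the row cooperatively.
    for (int i = tid; i < hidden_dim; i += blockDim.x)
        row_out[i] = (row_x[i] / rms) * gamma[i];
}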
// GPU entry point: maps the functor over all B*S rows via thrust.
torch::Tensor rmsnorm_forward(torch::Tensor& x, torch::Tensor& gamma) {
    TORCH_CHECK(x.is_cuda() && gamma.is_cuda(), "x and gamma must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && gamma.is_contiguous(),
                "x and gamma must be contiguous (raw pointers are indexed directly)");
    int B = x.size(0), S = x.size(1), H = x.size(2);
    int rows = B * S;
    // Create output tensor with the same shape, dtype, and device as the input
    auto out = torch::empty_like(x);
    const float* x_ptr = x.data_ptr<float>();
    const float* gamma_ptr = gamma.data_ptr<float>();
    float* out_ptr = out.data_ptr<float>();
    // Run the functor once per row index in [0, rows) on the device
    thrust::counting_iterator<int> iter(0);
    thrust::for_each(
        thrust::device,
        iter, iter + rows,
        RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H));
    return out;
}
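// Binding sketch (an assumption, not in the original file): if this is built
// as a PyTorch extension via torch.utils.cpp_extension, swap the include to
// <torch/extension.h> and expose the entry points like so:
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("rmsnorm_forward", &rmsnorm_forward, "RMSNorm forward (CUDA)");
//     m.def("rmsnorm_forward_cpu", &rmsnorm_forward_cpu, "RMSNorm forward (CPU)");
// }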
// Standalone test driver (left commented out in the original).
// int main() {
//     int B = 2, S = 2, H = 4;
//     // Create tensors on the CPU first
//     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
//     torch::Tensor x_cpu = torch::tensor({
//         {
//             {1.0f, 2.0f, 3.0f, 4.0f},
//             {5.0f, 6.0f, 7.0f, 8.0f}
//         },
//         {
//             {2.0f, 2.0f, 2.0f, 2.0f},
//             {9.0f, 10.0f, 11.0f, 12.0f}
//         }
//     }, options_cpu);
//     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
//     // Run the CPU reference version
//     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
//     auto cpu_accessor = out_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//     // Move the inputs to CUDA and run the GPU version
//     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
//     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
//     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
//     // Copy the result back to the CPU and print it
//     auto gpu_result_on_cpu = out_cuda.cpu();
//     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//     // Compare the two results elementwise
//     std::cout << "\n===== COMPARISON =====" << std::endl;
//     float max_diff = 0.0f;
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             for (int h = 0; h < H; ++h) {
//                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
//                 max_diff = std::max(max_diff, diff);
//             }
//         }
//     }
//     std::cout << "Maximum difference between CPU and GPU results: "
//               << std::scientific << max_diff << std::endl;
//     std::cout << (max_diff < 1e-5f ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
//     return 0;
// }
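// Build sketch (assumptions: file saved as rmsnorm.cu, LibTorch unpacked
// locally and passed via -DCMAKE_PREFIX_PATH; target names are illustrative).
// To run the test driver above, uncomment main() and use a CMakeLists.txt
// along these lines:
//
//   cmake_minimum_required(VERSION 3.18)
//   project(rmsnorm LANGUAGES CXX CUDA)
//   find_package(Torch REQUIRED)
//   add_executable(rmsnorm_test rmsnorm.cu)
//   target_link_libraries(rmsnorm_test "${TORCH_LIBRARIES}")
//   set_property(TARGET rmsnorm_test PROPERTY CXX_STANDARD 17)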