File size: 196 Bytes
25b4ce2 |
1 2 3 4 5 6 7 8 9 |
#pragma once
#include <torch/extension.h>
/// @brief Top-k softmax entry point exposed to the extension.
///
/// Only this declaration is visible here; the implementation lives in another
/// translation unit.
///
/// NOTE(review): the parameter semantics below are inferred from the names
/// and are NOT established by this header. Confirm them against the definition
/// before relying on them.
///
/// All four tensors are taken by non-const reference. Keep this signature in
/// exact sync with the definition: changing a parameter type here (e.g. adding
/// `const`) would declare a different overload and break linking.
///
/// @param topk_weights          Presumably an output receiving the softmax
///                              weights of the selected top-k entries.
///                              TODO confirm shape/dtype.
/// @param topk_indices          Presumably an output receiving the indices of
///                              the selected top-k entries. TODO confirm dtype.
/// @param token_expert_indices  Presumably an output holding per-(token, k)
///                              bookkeeping indices. TODO confirm layout.
/// @param gating_output         Presumably the input gating logits, e.g.
///                              (num_tokens, num_experts). TODO confirm.
void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output);