triton-kernels

triton-kernels is a set of kernels that enable fast MoE (mixture-of-experts) computation on different GPU architectures. These kernels support several precisions (e.g. bf16, mxfp4).
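To make the mxfp4 precision concrete: in the OCP Microscaling (MX) format, each element is a 4-bit E2M1 float, and each block of 32 elements shares one 8-bit E8M0 power-of-two scale. The pure-Python sketch below decodes one such block following that format definition. The helper names (`decode_e2m1`, `decode_mxfp4_block`) are hypothetical, and the sketch does not model how triton_kernels actually stores mxfp4 tensors (for example, packing two codes per byte or laying out scales for the GPU).

```python
# Hypothetical helpers: decode MXFP4 values per the OCP Microscaling (MX) format.
# Each element is a 4-bit E2M1 code; a block of up to 32 elements shares one E8M0 scale.

# Magnitudes of the 8 non-negative E2M1 codes (bit 3 is the sign bit).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32  # MX block size: 32 elements share one scale


def decode_e2m1(code: int) -> float:
    """Decode a single 4-bit E2M1 code (0..15) to a Python float."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_VALUES[code & 0x7]


def decode_mxfp4_block(codes: list[int], scale_e8m0: int) -> list[float]:
    """Decode a block of E2M1 codes that share one E8M0 scale of 2**(e - 127)."""
    if len(codes) > BLOCK_SIZE:
        raise ValueError("an MX block holds at most 32 elements")
    if scale_e8m0 == 0xFF:
        raise ValueError("E8M0 value 0xFF encodes NaN")
    scale = 2.0 ** (scale_e8m0 - 127)
    return [decode_e2m1(c) * scale for c in codes]


# Scale byte 128 means 2**1 = 2, so codes for 1.0, -1.5 and 6.0 become 2.0, -3.0, 12.0.
print(decode_mxfp4_block([0b0010, 0b1011, 0b0111], 128))  # [2.0, -3.0, 12.0]
```

With only 4 bits per element plus one shared byte per 32 elements, mxfp4 weights take roughly a quarter of the memory of bf16, at the cost of a very coarse value grid.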

The original code is available at https://github.com/triton-lang/triton/tree/main/python/triton_kernels

The current version is based on commit 7d0efaa7231661299284a603512fce4fa255e62c.

Note that we can't update these kernels freely, because some upstream commits may rely on Triton main. Unfortunately, we have to wait for a new Triton release. See the related issue: https://github.com/triton-lang/triton/issues/7818

Quickstart

uv run https://huggingface.co/kernels-community/triton_kernels/raw/main/readme_example.py
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "triton",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from kernels import get_kernel

torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Load triton_kernels module via kernels library
triton_kernels = get_kernel("kernels-community/triton_kernels")

# Access modules directly from the loaded kernel
swiglu = triton_kernels.swiglu
routing = triton_kernels.routing

# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# SwiGLU example
x = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)
y = swiglu.swiglu_torch(x, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"SwiGLU: {x.shape} -> {y.shape}")

# Routing example
logits = torch.randn(128, 8, device=device, dtype=torch.float16)
routing_data, gather_idx, scatter_idx = routing.routing_torch(logits, n_expts_act=2)
print(f"Routing: {routing_data.expt_hist.sum()} tokens routed")

# MoE integration: apply SwiGLU to one activation row per routed (token, expert) assignment
n_tokens = routing_data.expt_hist.sum().item()
x_moe = torch.randn(n_tokens, 512, device=device, dtype=torch.bfloat16)
y_moe = swiglu.swiglu_torch(x_moe, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"MoE SwiGLU: {x_moe.shape} -> {y_moe.shape}")
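The quickstart exercises two MoE building blocks: top-k routing (`routing_torch` with `n_expts_act=2`) and a gated SwiGLU activation that halves the last dimension (1024 columns in, 512 out). The pure-Python sketch below illustrates both ideas under stated assumptions; `topk_route` and `swiglu_row` are hypothetical helpers, not the triton_kernels API. The router takes a softmax over only the selected logits (one common convention), and the SwiGLU assumes interleaved gate/linear columns clamped to `[-limit, limit]`; the exact formulas in `routing_torch` and `swiglu_torch` may differ in these details.

```python
import math


def topk_route(logits, n_expts_act):
    """Hypothetical top-k router over per-token rows of expert scores.

    Returns per-token [(expert, weight), ...] choices and a per-expert histogram.
    """
    n_expts = len(logits[0])
    hist = [0] * n_expts
    choices = []
    for row in logits:
        # Keep the n_expts_act highest-scoring experts for this token.
        top = sorted(range(n_expts), key=lambda e: row[e], reverse=True)[:n_expts_act]
        # Softmax over the selected logits only (an assumed convention).
        m = max(row[e] for e in top)
        exps = [math.exp(row[e] - m) for e in top]
        total = sum(exps)
        choices.append([(e, x / total) for e, x in zip(top, exps)])
        # Each (token, expert) assignment adds one count to that expert's bin.
        for e in top:
            hist[e] += 1
    return choices, hist


def swiglu_row(row, alpha, limit):
    """Hypothetical SwiGLU over one row of even length.

    Assumes gate/linear values are interleaved (even/odd columns) and clamps both
    to [-limit, limit]; the output has half as many columns as the input.
    """
    out = []
    for i in range(0, len(row), 2):
        gate = max(-limit, min(limit, row[i]))
        linear = max(-limit, min(limit, row[i + 1]))
        # SiLU-style gating: gate * sigmoid(alpha * gate), scaled by the linear half.
        out.append(gate / (1.0 + math.exp(-alpha * gate)) * linear)
    return out


# 2 tokens, 4 experts, top-2 routing -> 4 (token, expert) assignments in total.
choices, hist = topk_route([[0.1, 2.0, 1.0, -1.0], [3.0, 0.0, 0.5, 0.2]], n_expts_act=2)
print(hist, sum(hist))  # [1, 1, 2, 0] 4
print(swiglu_row([1.0, 0.5, 0.0, 2.0], alpha=0.5, limit=1.0))  # 2 output values
```

In this sketch the histogram sums to tokens × `n_expts_act`, not to the number of tokens. If `expt_hist` in `routing_torch` follows the same convention, the quickstart's 128 tokens with `n_expts_act=2` would yield 256 assignments, which is why `x_moe` gets one row per assignment.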