---
license: mit
tags:
- kernel
---
# triton-kernels

triton-kernels is a set of kernels that enable fast MoE (mixture of experts) on different GPU architectures. The kernels support multiple precisions (e.g. bf16, mxfp4).

The original code lives at https://github.com/triton-lang/triton/tree/main/python/triton_kernels.

The current version corresponds to commit 7d0efaa7231661299284a603512fce4fa255e62c.

Note that we can't update these kernels at will: some upstream commits rely on Triton main, so we have to wait for a new Triton release. See the related issue: https://github.com/triton-lang/triton/issues/7818


## Quickstart
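
You can run the self-contained example straight from the Hub with `uv`: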

```bash
uv run https://huggingface.co/kernels-community/triton_kernels/raw/main/readme_example.py
```
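
The script declares its dependencies as inline metadata (the `# /// script` block, PEP 723), so `uv` resolves and installs them automatically before running it: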

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "triton",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from kernels import get_kernel

torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Load triton_kernels module via kernels library
triton_kernels = get_kernel("kernels-community/triton_kernels")

# Access modules directly from the loaded kernel
swiglu = triton_kernels.swiglu
routing = triton_kernels.routing

# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# SwiGLU example
x = torch.randn(512, 1024, device=device, dtype=torch.bfloat16)
y = swiglu.swiglu_torch(x, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"SwiGLU: {x.shape} -> {y.shape}")

# Routing example
logits = torch.randn(128, 8, device=device, dtype=torch.float16)
routing_data, gather_idx, scatter_idx = routing.routing_torch(logits, n_expts_act=2)
print(f"Routing: {routing_data.expt_hist.sum()} tokens routed")

# Integrated MoE-style usage: apply SwiGLU to the routed tokens
n_tokens = routing_data.expt_hist.sum().item()
x_moe = torch.randn(n_tokens, 512, device=device, dtype=torch.bfloat16)
y_moe = swiglu.swiglu_torch(x_moe, 0.5, swiglu.PrecisionConfig(limit=1.0))
print(f"MoE SwiGLU: {x_moe.shape} -> {y_moe.shape}")
```
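
As a quick sanity check, assuming `routing_data.expt_hist` is the per-expert token histogram (as the example's output suggests): with `n_expts_act=2`, each of the 128 tokens is dispatched to 2 experts, so the histogram should sum to 256 routed (token, expert) pairs.

```python
# Sanity-check sketch (assumes expt_hist counts routed tokens per expert;
# this reading of the field is inferred from the example above).
n_routed = routing_data.expt_hist.sum().item()
assert n_routed == 128 * 2, f"expected 256 routed pairs, got {n_routed}"
print(f"{n_routed} (token, expert) pairs across {routing_data.expt_hist.numel()} experts")
```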