3outeille committed
Commit 89df403 · 1 Parent(s): eb55039

add failing test when expert_capacity > 65535
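The large parameterizations in the test below are the cases expected to fail: CUDA caps the second (and third) grid dimension of a kernel launch at 65535, so a gather kernel that places expert_capacity on that axis cannot be launched once expert_capacity exceeds that bound. A rough sketch of the bound follows; the constant and helper name are illustrative only and not megablocks APIs.

# Hypothetical sketch, not part of megablocks: CUDA limits gridDim.y and gridDim.z
# to 65535, so a 2D launch of the form grid = (num_bins, expert_capacity) is only
# valid while expert_capacity stays within that bound.
CUDA_MAX_GRID_DIM_YZ = 65535

def fits_second_grid_dim(expert_capacity: int) -> bool:
    """True if expert_capacity can be used directly as the second grid dimension."""
    return expert_capacity <= CUDA_MAX_GRID_DIM_YZ

assert fits_second_grid_dim(65535)       # boundary value included in the test grid below
assert not fits_second_grid_dim(70000)   # exceeds the limit; expected to fail below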
tests/ops/binned_copy.py
ADDED
@@ -0,0 +1,146 @@
import torch
import pytest

from megablocks.ops.binned_gather import BinnedGatherOp

binned_gather_triton = BinnedGatherOp.apply


def set_seeds(seed=0):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Stress-test expert_capacity, especially near and at the upper limit
# (65535, the CUDA limit on the second grid dimension).
def make_stress_expert_capacity_tests():
    tests = []
    # Small cases for sanity.
    for seq_len, hidden_size, num_experts, top_k in [
        (4, 2, 2, 1),
        (4, 2, 2, 2),
        (4, 2, 2, 4),
    ]:
        for expert_capacity in [1, 2, 4]:
            tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
    # Medium cases.
    for seq_len, hidden_size, num_experts, top_k in [
        (1024, 1536, 4, 1),
        (1024, 1536, 4, 2),
        (1024, 1536, 4, 4),
        (1024, 1536, 64, 1),
        (1024, 1536, 64, 2),
        (1024, 1536, 64, 4),
        (1024, 1536, 128, 1),
        (1024, 1536, 128, 2),
        (1024, 1536, 128, 4),
    ]:
        for expert_capacity in [1, 2, 4, 128, 1024]:
            tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))

    # Large cases: stress expert_capacity at and beyond 65535 (CUDA second grid-dimension limit).
    for seq_len, hidden_size, num_experts, top_k in [
        (4096, 768, 32, 4),
    ]:
        for expert_capacity in [65535, 70000, 100000, 1000000]:
            tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))

    return tuple(tests)


BINNED_GATHER_TESTS = make_stress_expert_capacity_tests()

@pytest.mark.parametrize(('seq_len', 'hidden_size', 'num_experts', 'top_k', 'expert_capacity'), BINNED_GATHER_TESTS)
def test_binned_gather(seq_len: int, hidden_size: int, num_experts: int, top_k: int, expert_capacity: int):
    # NOTE: Capacity factor == 1.
    set_seeds(42)
    # Create the input data with gradient tracking.
    x = torch.arange(seq_len * hidden_size, device='cuda', dtype=torch.half).view(seq_len, hidden_size)
    x.requires_grad_(True)

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, num_experts, (seq_len * top_k,), device='cuda', dtype=torch.int)
    _, indices = torch.sort(top_expert)
    bins = torch.cumsum(torch.bincount(top_expert, minlength=num_experts), 0).to(torch.int32)
    # Example: if the per-expert counts are [12, 2, 6], the bins tensor is their
    # cumulative sum [12, 14, 20]. This tells the gather function:
    #   Expert 0's assignments are in indices[0:12].
    #   Expert 1's assignments are in indices[12:14].
    #   Expert 2's assignments are in indices[14:20].
    # (There are seq_len * top_k routing entries in total.)

    def binned_gather_pytorch(
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        expert_capacity: int,
        top_k: int,
    ):
        start = 0
        out = torch.zeros((num_experts, expert_capacity, hidden_size), dtype=x.dtype, device=x.device)
        for i in range(num_experts):
            end = bins[i]
            num_tokens = min(expert_capacity, end - start)
            if num_tokens > 0:
                # indices[start:end] are the indices for this expert.
                # For each slot j, get the input index and copy the row.
                idx = indices[start : start + num_tokens] // top_k
                print(f"Expert {i}: indices[{start}:{start + num_tokens}] = {indices[start : start + num_tokens]} -> tokens {idx}")
                out[i, :num_tokens, :] = x[idx, :]
            start = end
        return out

    out = binned_gather_triton(x, indices, bins, expert_capacity, top_k)
    expected_out = binned_gather_pytorch(x, indices, bins, expert_capacity, top_k)
    assert torch.all(torch.eq(out, expected_out))

    # Test backward pass.
    grad_output = torch.arange(out.numel(), device=out.device, dtype=out.dtype).view_as(out)
    out.backward(grad_output)

    # Verify gradients were computed.
    assert x.grad is not None, "Gradients should be computed for input x"
    assert x.grad.shape == x.shape, f"Gradient shape {x.grad.shape} should match input shape {x.shape}"

    # Reference implementation for backward pass (binned_scatter).
    def binned_scatter_pytorch(
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        # x: (ne, ec, hs)
        # indices: (sl * top_k,)
        # weights: (sl * top_k,)
        # bins: (ne,)
        # Output: (sl, hs)
        out = torch.zeros((seq_len, hidden_size), device=x.device, dtype=x.dtype)
        start = 0
        for i in range(num_experts):
            end = bins[i]
            num_tokens = min(expert_capacity, end - start)
            for j in range(num_tokens):
                index = indices[start + j]
                scale = weights[index] if weights is not None else 1.0
                token_pos = index // top_k
                out[token_pos, :] += scale * x[i, j, :]
            start = end
        return out

    expected_grad = binned_scatter_pytorch(grad_output, indices, None, bins, top_k)
    print(f"x.grad: {x.grad}")
    print(f"expected_grad: {expected_grad}")

    # Use torch.allclose instead of exact equality for floating point comparison.
    if torch.allclose(x.grad, expected_grad, rtol=1e-3, atol=1e-3):
        print("✅ Success: Gradients match!")
    else:
        print("❌ Gradients don't match")
        # Let's see if it's just a reordering issue.
        print("Checking if values match when sorted...")
        grad_sorted = torch.sort(x.grad.flatten())[0]
        expected_sorted = torch.sort(expected_grad.flatten())[0]
        if torch.allclose(grad_sorted, expected_sorted, rtol=1e-3, atol=1e-3):
            print("✅ Same values, different order - routing issue!")
        else:
            print("❌ Different values entirely")

    print(f"\nTriton grad shape: {x.grad.shape}")
    print(f"PyTorch reference grad shape: {expected_grad.shape}")
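For reference, the bincount/cumsum binning that test_binned_gather above builds can be traced on a tiny standalone example. The values here are made up for illustration and are not part of the committed file.

import torch

# Six routing entries over three experts (hypothetical assignments).
top_expert = torch.tensor([0, 2, 1, 0, 2, 0], dtype=torch.int)

counts = torch.bincount(top_expert, minlength=3)   # tensor([3, 1, 2])
bins = torch.cumsum(counts, 0).to(torch.int32)     # tensor([3, 4, 6])
_, indices = torch.sort(top_expert)                # routing entries grouped by expert

# indices[0:3] -> entries assigned to expert 0
# indices[3:4] -> entry assigned to expert 1
# indices[4:6] -> entries assigned to expert 2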