fix-grid-limits
#2
by
3outeille
HF Staff
- opened
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py +7 -7
- tests/ops/binned_copy.py +145 -0
- torch-ext/megablocks/backend/kernels.py +7 -7
build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
tests/ops/binned_copy.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import pytest
|
3 |
+
|
4 |
+
from megablocks.ops.binned_gather import BinnedGatherOp
|
5 |
+
|
6 |
+
binned_gather_triton = BinnedGatherOp.apply
|
7 |
+
|
8 |
+
def set_seeds(seed=0):
|
9 |
+
torch.manual_seed(seed)
|
10 |
+
if torch.cuda.is_available():
|
11 |
+
torch.cuda.manual_seed_all(seed)
|
12 |
+
|
13 |
+
# Stress test expert_capacity, especially near and at the upper limit (e.g., 65535 for int16 indexing)
|
14 |
+
def make_stress_expert_capacity_tests():
|
15 |
+
tests = []
|
16 |
+
# Small cases for sanity
|
17 |
+
for seq_len, hidden_size, num_experts, top_k in [
|
18 |
+
(4, 2, 2, 1),
|
19 |
+
(4, 2, 2, 2),
|
20 |
+
(4, 2, 2, 4),
|
21 |
+
]:
|
22 |
+
for expert_capacity in [1, 2, 4]:
|
23 |
+
tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
|
24 |
+
# Medium cases
|
25 |
+
for seq_len, hidden_size, num_experts, top_k in [
|
26 |
+
(1024, 1536, 4, 1),
|
27 |
+
(1024, 1536, 4, 2),
|
28 |
+
(1024, 1536, 4, 4),
|
29 |
+
(1024, 1536, 64, 1),
|
30 |
+
(1024, 1536, 64, 2),
|
31 |
+
(1024, 1536, 64, 4),
|
32 |
+
(1024, 1536, 128, 1),
|
33 |
+
(1024, 1536, 128, 2),
|
34 |
+
(1024, 1536, 128, 4),
|
35 |
+
]:
|
36 |
+
for expert_capacity in [1, 2, 4, 128, 1024]:
|
37 |
+
tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
|
38 |
+
|
39 |
+
# Large cases, stress expert_capacity near 65536 (CUDA second dim grid limit)
|
40 |
+
for seq_len, hidden_size, num_experts, top_k in [
|
41 |
+
(4096, 768, 32, 4),
|
42 |
+
]:
|
43 |
+
for expert_capacity in [65535, 70000, 90000]:
|
44 |
+
tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
|
45 |
+
|
46 |
+
return tuple(tests)
|
47 |
+
|
48 |
+
BINNED_GATHER_TESTS = make_stress_expert_capacity_tests()
|
49 |
+
|
50 |
+
@pytest.mark.parametrize(('seq_len', 'hidden_size', 'num_experts', 'top_k', 'expert_capacity'), BINNED_GATHER_TESTS)
|
51 |
+
def test_binned_gather(seq_len: int, hidden_size: int, num_experts: int, top_k: int, expert_capacity: int):
|
52 |
+
# NOTE: Capacity factor == 1.
|
53 |
+
set_seeds(42)
|
54 |
+
# Create the data and indices with gradient tracking
|
55 |
+
x = torch.arange(seq_len * hidden_size, device='cuda', dtype=torch.half).view(seq_len, hidden_size)
|
56 |
+
x.requires_grad_(True)
|
57 |
+
|
58 |
+
# Randomly assign tokens to experts.
|
59 |
+
top_expert = torch.randint(0, num_experts, (seq_len * top_k,), device='cuda', dtype=torch.int)
|
60 |
+
_, indices = torch.sort(top_expert)
|
61 |
+
bins = torch.cumsum(torch.bincount(top_expert, minlength=num_experts), 0).to(torch.int32)
|
62 |
+
# Example: counts is [12, 2, 3], the bins tensor will be [12, 14, 20]. This tells the gather function:
|
63 |
+
# Expert 0's assignments are in indices[0:12].
|
64 |
+
# Expert 1's assignments are in indices[12:14].
|
65 |
+
# Expert 2's assignments are in indices[14:20]. (we have num_tokens * 3)
|
66 |
+
|
67 |
+
def binned_gather_pytorch(
|
68 |
+
x: torch.Tensor,
|
69 |
+
indices: torch.Tensor,
|
70 |
+
bins: torch.Tensor,
|
71 |
+
expert_capacity: int,
|
72 |
+
top_k: int,
|
73 |
+
):
|
74 |
+
start = 0
|
75 |
+
out = torch.zeros((num_experts, expert_capacity, hidden_size), dtype=x.dtype, device=x.device)
|
76 |
+
for i in range(num_experts):
|
77 |
+
end = bins[i]
|
78 |
+
num_tokens = min(expert_capacity, end - start)
|
79 |
+
if num_tokens > 0:
|
80 |
+
# indices[start:end] are the indices for this expert
|
81 |
+
# For each slot j, get the input index and copy the row
|
82 |
+
idx = indices[start : start + num_tokens] // top_k
|
83 |
+
out[i, :num_tokens, :] = x[idx, :]
|
84 |
+
start = end
|
85 |
+
return out
|
86 |
+
|
87 |
+
out = binned_gather_triton(x, indices, bins, expert_capacity, top_k)
|
88 |
+
expected_out = binned_gather_pytorch(x, indices, bins, expert_capacity, top_k)
|
89 |
+
assert torch.all(torch.eq(out, expected_out))
|
90 |
+
|
91 |
+
# Test backward pass
|
92 |
+
grad_output = torch.arange(out.numel(), device=out.device, dtype=out.dtype).view_as(out)
|
93 |
+
out.backward(grad_output)
|
94 |
+
|
95 |
+
# Verify gradients were computed
|
96 |
+
assert x.grad is not None, "Gradients should be computed for input x"
|
97 |
+
assert x.grad.shape == x.shape, f"Gradient shape {x.grad.shape} should match input shape {x.shape}"
|
98 |
+
|
99 |
+
# Reference implementation for backward pass (binned_scatter)
|
100 |
+
def binned_scatter_pytorch(
|
101 |
+
x: torch.Tensor,
|
102 |
+
indices: torch.Tensor,
|
103 |
+
weights: torch.Tensor,
|
104 |
+
bins: torch.Tensor,
|
105 |
+
top_k: int,
|
106 |
+
):
|
107 |
+
# x: (ne, ec, hs)
|
108 |
+
# indices: (sl * top_k,)
|
109 |
+
# weights: (sl * top_k,)
|
110 |
+
# bins: (ne,)
|
111 |
+
# Output: (sl, hs)
|
112 |
+
out = torch.zeros((seq_len, hidden_size), device=x.device, dtype=x.dtype)
|
113 |
+
start = 0
|
114 |
+
for i in range(num_experts):
|
115 |
+
end = bins[i]
|
116 |
+
num_tokens = min(expert_capacity, end - start)
|
117 |
+
for j in range(num_tokens):
|
118 |
+
index = indices[start + j]
|
119 |
+
scale = weights[index] if weights is not None else 1.0
|
120 |
+
token_pos = index // top_k
|
121 |
+
|
122 |
+
out[token_pos, :] += scale * x[i, j, :]
|
123 |
+
start = end
|
124 |
+
return out
|
125 |
+
|
126 |
+
expected_grad = binned_scatter_pytorch(grad_output, indices, None, bins, top_k)
|
127 |
+
print(f"x.grad: {x.grad}")
|
128 |
+
print(f"expected_grad: {expected_grad}")
|
129 |
+
|
130 |
+
# Use torch.allclose instead of exact equality for floating point comparison
|
131 |
+
if torch.allclose(x.grad, expected_grad, rtol=1e-3, atol=1e-3):
|
132 |
+
print("✅ Success: Gradients match!")
|
133 |
+
else:
|
134 |
+
print("❌ Gradients don't match")
|
135 |
+
# Let's see if it's just a reordering issue
|
136 |
+
print("Checking if values match when sorted...")
|
137 |
+
grad_sorted = torch.sort(x.grad.flatten())[0]
|
138 |
+
expected_sorted = torch.sort(expected_grad.flatten())[0]
|
139 |
+
if torch.allclose(grad_sorted, expected_sorted, rtol=1e-3, atol=1e-3):
|
140 |
+
print("✅ Same values, different order - routing issue!")
|
141 |
+
else:
|
142 |
+
print("❌ Different values entirely")
|
143 |
+
|
144 |
+
print(f"\nTriton Output Shape: {x.grad.shape}")
|
145 |
+
print(f"PyTorch Output Shape: {expected_grad.shape}")
|
torch-ext/megablocks/backend/kernels.py
CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
-
expert_idx = tl.program_id(
|
356 |
-
entry_idx = tl.program_id(
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
-
_binned_copy[(
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
-
_binned_copy[(
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
-
expert_idx = tl.program_id(
|
496 |
-
entry_idx = tl.program_id(
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
-
_binned_copy_wgrad[(
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|
|
|
352 |
SCALE: tl.constexpr,
|
353 |
):
|
354 |
# Load our indices into the output.
|
355 |
+
expert_idx = tl.program_id(1)
|
356 |
+
entry_idx = tl.program_id(0)
|
357 |
|
358 |
# Calculate our offset into the output.
|
359 |
index_b = expert_idx * expert_capacity + entry_idx
|
|
|
416 |
num_experts = bins.shape[0]
|
417 |
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
418 |
|
419 |
+
_binned_copy[(expert_capacity, num_experts)](
|
420 |
x,
|
421 |
out,
|
422 |
num_experts,
|
|
|
445 |
num_experts, expert_capacity, hidden_size = x.shape
|
446 |
tokens = indices.shape[0] // top_k
|
447 |
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
448 |
+
_binned_copy[(expert_capacity, num_experts)](
|
449 |
out,
|
450 |
x,
|
451 |
num_experts,
|
|
|
492 |
BLOCK_X: tl.constexpr,
|
493 |
):
|
494 |
# Load our indices into the output.
|
495 |
+
expert_idx = tl.program_id(1)
|
496 |
+
entry_idx = tl.program_id(0)
|
497 |
|
498 |
# Calculate our offset into the output.
|
499 |
index_x = expert_idx * expert_capacity + entry_idx
|
|
|
543 |
num_experts, expert_capacity, hidden_size = x.shape
|
544 |
tokens = indices.shape[0] // top_k
|
545 |
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
546 |
+
_binned_copy_wgrad[(expert_capacity, num_experts)](
|
547 |
x,
|
548 |
grad,
|
549 |
out,
|