3outeille committed
Commit: 5ba4d9f
Parent(s): 29553ae
update binned_copy kernels
Browse files
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py (+7 -7)
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py (+7 -7)
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py (+7 -7)
- build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py (+7 -7)
- build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py (+7 -7)
- build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py (+7 -7)
- torch-ext/megablocks/backend/kernels.py (+7 -7)
build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
CHANGED

@@ -352,8 +352,8 @@ def _binned_copy(
     SCALE: tl.constexpr,
 ):
     # Load our indices into the output.
-    expert_idx = tl.program_id(
-    entry_idx = tl.program_id(
+    expert_idx = tl.program_id(1)
+    entry_idx = tl.program_id(0)
 
     # Calculate our offset into the output.
     index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
     num_experts = bins.shape[0]
     out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
 
-    _binned_copy[(
+    _binned_copy[(expert_capacity, num_experts)](
         x,
         out,
         num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
     num_experts, expert_capacity, hidden_size = x.shape
     tokens = indices.shape[0] // top_k
     out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
-    _binned_copy[(
+    _binned_copy[(expert_capacity, num_experts)](
         out,
         x,
         num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
     BLOCK_X: tl.constexpr,
 ):
     # Load our indices into the output.
-    expert_idx = tl.program_id(
-    entry_idx = tl.program_id(
+    expert_idx = tl.program_id(1)
+    entry_idx = tl.program_id(0)
 
     # Calculate our offset into the output.
     index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
-    _binned_copy_wgrad[(
+    _binned_copy_wgrad[(expert_capacity, num_experts)](
        x,
        grad,
        out,
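The substantive change, repeated verbatim in every copy of kernels.py, is the orientation of the launch grid: the entry/capacity dimension now sits on grid axis 0 and the expert dimension on axis 1, and the tl.program_id reads inside _binned_copy and _binned_copy_wgrad are swapped to match, so the computed row index expert_idx * expert_capacity + entry_idx stays the same. The snippet below is a minimal, hypothetical Triton kernel (not the megablocks code; _copy_rows and copy_binned are made-up names for illustration) showing how positions in the grid tuple pair with tl.program_id axes.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows(src, dst, expert_capacity, hidden, BLOCK: tl.constexpr):
    # Grid is (expert_capacity, num_experts): axis 0 walks entries within an
    # expert's bin, axis 1 walks experts -- the same pairing as the updated
    # _binned_copy above.
    entry_idx = tl.program_id(0)
    expert_idx = tl.program_id(1)
    # The row index is invariant under the axis swap as long as the grid
    # tuple and the program_id reads are swapped together.
    row = expert_idx * expert_capacity + entry_idx
    offs = tl.arange(0, BLOCK)
    mask = offs < hidden
    vals = tl.load(src + row * hidden + offs, mask=mask)
    tl.store(dst + row * hidden + offs, vals, mask=mask)


def copy_binned(src: torch.Tensor) -> torch.Tensor:
    # src: contiguous (num_experts, expert_capacity, hidden) tensor on a GPU.
    num_experts, expert_capacity, hidden = src.shape
    dst = torch.empty_like(src)
    _copy_rows[(expert_capacity, num_experts)](
        src,
        dst,
        expert_capacity,
        hidden,
        BLOCK=triton.next_power_of_2(hidden),
    )
    return dst

Calling copy_binned on a (num_experts, expert_capacity, hidden) tensor simply round-trips the data, so the axis convention can be checked with torch.equal(copy_binned(x), x).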
build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Same +7/-7 diff as build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above.
build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Same +7/-7 diff as build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above.
build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Same +7/-7 diff as build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above.
build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Same +7/-7 diff as build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above.
build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Same +7/-7 diff as build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above.
torch-ext/megablocks/backend/kernels.py
CHANGED
Same +7/-7 diff as build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above.
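The commit message does not say why the axis order was swapped; a plausible reason (an assumption, not confirmed by this diff) is the asymmetric CUDA grid limits: the first grid dimension allows up to 2**31 - 1 blocks while the second and third are capped at 65535, and expert_capacity grows with the number of routed tokens whereas num_experts stays small. A hypothetical pre-launch check under that assumption:

# Hypothetical sanity check, assuming standard CUDA grid limits
# (2**31 - 1 blocks on grid axis 0, 65535 on axes 1 and 2).
MAX_GRID_DIM0 = 2**31 - 1
MAX_GRID_DIM1 = 65535


def grid_fits(expert_capacity: int, num_experts: int) -> bool:
    """Return True if a (expert_capacity, num_experts) launch stays within the limits."""
    return expert_capacity <= MAX_GRID_DIM0 and num_experts <= MAX_GRID_DIM1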