3outeille committed
Commit · 5ba4d9f · 1 Parent(s): 29553ae
update binned_copy kernels
Browse files
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py +7 -7
- build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py +7 -7
- torch-ext/megablocks/backend/kernels.py +7 -7
build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
CHANGED

@@ -352,8 +352,8 @@ def _binned_copy(
     SCALE: tl.constexpr,
 ):
     # Load our indices into the output.
-    expert_idx = tl.program_id(0)
-    entry_idx = tl.program_id(1)
+    expert_idx = tl.program_id(1)
+    entry_idx = tl.program_id(0)

     # Calculate our offset into the output.
     index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
     num_experts = bins.shape[0]
     out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

-    _binned_copy[(num_experts, expert_capacity)](
+    _binned_copy[(expert_capacity, num_experts)](
         x,
         out,
         num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
     num_experts, expert_capacity, hidden_size = x.shape
     tokens = indices.shape[0] // top_k
     out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
-    _binned_copy[(num_experts, expert_capacity)](
+    _binned_copy[(expert_capacity, num_experts)](
         out,
         x,
         num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
     BLOCK_X: tl.constexpr,
 ):
     # Load our indices into the output.
-    expert_idx = tl.program_id(0)
-    entry_idx = tl.program_id(1)
+    expert_idx = tl.program_id(1)
+    entry_idx = tl.program_id(0)

     # Calculate our offset into the output.
     index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
     num_experts, expert_capacity, hidden_size = x.shape
     tokens = indices.shape[0] // top_k
     out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
-    _binned_copy_wgrad[(num_experts, expert_capacity)](
+    _binned_copy_wgrad[(expert_capacity, num_experts)](
         x,
         grad,
         out,
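For reference, a minimal sketch (hypothetical, not part of this commit) of the Triton grid-axis mapping the hunks above rely on: the first element of the launch grid is read back with tl.program_id(0) and the second with tl.program_id(1), so reordering the grid to (expert_capacity, num_experts) goes hand in hand with swapping which axis the kernel reads for entry_idx and expert_idx. One plausible motivation, not stated in the commit, is that CUDA caps grid axes 1 and 2 at 65535 blocks while axis 0 allows up to 2^31 - 1, so keeping the potentially large expert_capacity dimension on axis 0 avoids launch failures. The demo kernel below is illustrative only and assumes triton plus a CUDA GPU.

# Hypothetical demo (not from megablocks): shows that a grid of
# (expert_capacity, num_experts) paired with program_id(1)/program_id(0)
# enumerates every (expert, entry) pair exactly once.
import torch
import triton
import triton.language as tl


@triton.jit
def _grid_axis_demo(out, expert_capacity: tl.constexpr):
    # Axis 1 enumerates experts, axis 0 enumerates entries, as in the
    # updated _binned_copy kernel.
    expert_idx = tl.program_id(1)
    entry_idx = tl.program_id(0)
    index = expert_idx * expert_capacity + entry_idx
    tl.store(out + index, index)


num_experts, expert_capacity = 4, 8
out = torch.empty(num_experts * expert_capacity, dtype=torch.int32, device="cuda")
_grid_axis_demo[(expert_capacity, num_experts)](out, expert_capacity)
# Every flattened offset 0..31 is written exactly once.
assert torch.equal(out.cpu(), torch.arange(num_experts * expert_capacity, dtype=torch.int32))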
build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Identical hunks to build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above (+7 -7).
build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Identical hunks to build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above (+7 -7).
build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Identical hunks to build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above (+7 -7).
build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Identical hunks to build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above (+7 -7).
build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py
CHANGED
Identical hunks to build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above (+7 -7).
torch-ext/megablocks/backend/kernels.py
CHANGED
Identical hunks to build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py above (+7 -7).