3outeille commited on
Commit
5ba4d9f
·
1 Parent(s): 29553ae

update binned_copy kernels

Browse files
build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,
build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,
build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,
build/torch28-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,
build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,
build/torch28-cxx11-cu129-x86_64-linux/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,
torch-ext/megablocks/backend/kernels.py CHANGED
@@ -352,8 +352,8 @@ def _binned_copy(
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
- expert_idx = tl.program_id(0)
356
- entry_idx = tl.program_id(1)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
@@ -416,7 +416,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
- _binned_copy[(num_experts, expert_capacity)](
420
  x,
421
  out,
422
  num_experts,
@@ -445,7 +445,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
- _binned_copy[(num_experts, expert_capacity)](
449
  out,
450
  x,
451
  num_experts,
@@ -492,8 +492,8 @@ def _binned_copy_wgrad(
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
- expert_idx = tl.program_id(0)
496
- entry_idx = tl.program_id(1)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
@@ -543,7 +543,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
- _binned_copy_wgrad[(num_experts, expert_capacity)](
547
  x,
548
  grad,
549
  out,
 
352
  SCALE: tl.constexpr,
353
  ):
354
  # Load our indices into the output.
355
+ expert_idx = tl.program_id(1)
356
+ entry_idx = tl.program_id(0)
357
 
358
  # Calculate our offset into the output.
359
  index_b = expert_idx * expert_capacity + entry_idx
 
416
  num_experts = bins.shape[0]
417
  out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
 
419
+ _binned_copy[(expert_capacity, num_experts)](
420
  x,
421
  out,
422
  num_experts,
 
445
  num_experts, expert_capacity, hidden_size = x.shape
446
  tokens = indices.shape[0] // top_k
447
  out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(expert_capacity, num_experts)](
449
  out,
450
  x,
451
  num_experts,
 
492
  BLOCK_X: tl.constexpr,
493
  ):
494
  # Load our indices into the output.
495
+ expert_idx = tl.program_id(1)
496
+ entry_idx = tl.program_id(0)
497
 
498
  # Calculate our offset into the output.
499
  index_x = expert_idx * expert_capacity + entry_idx
 
543
  num_experts, expert_capacity, hidden_size = x.shape
544
  tokens = indices.shape[0] // top_k
545
  out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(expert_capacity, num_experts)](
547
  x,
548
  grad,
549
  out,