3outeille commited on
Commit
29553ae
·
1 Parent(s): 89df403
Files changed (1) hide show
  1. tests/ops/binned_copy.py +1 -2
tests/ops/binned_copy.py CHANGED
@@ -40,7 +40,7 @@ def make_stress_expert_capacity_tests():
40
  for seq_len, hidden_size, num_experts, top_k in [
41
  (4096, 768, 32, 4),
42
  ]:
43
- for expert_capacity in [65535, 70000, 100000, 1000000]:
44
  tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
45
 
46
  return tuple(tests)
@@ -80,7 +80,6 @@ def test_binned_gather(seq_len: int, hidden_size: int, num_experts: int, top_k:
80
  # indices[start:end] are the indices for this expert
81
  # For each slot j, get the input index and copy the row
82
  idx = indices[start : start + num_tokens] // top_k
83
- print(f"Expert {i}: indices[{start}:{start + num_tokens}] = {indices[start : start + num_tokens]} -> tokens {idx}")
84
  out[i, :num_tokens, :] = x[idx, :]
85
  start = end
86
  return out
 
40
  for seq_len, hidden_size, num_experts, top_k in [
41
  (4096, 768, 32, 4),
42
  ]:
43
+ for expert_capacity in [65535, 70000, 90000]:
44
  tests.append((seq_len, hidden_size, num_experts, top_k, expert_capacity))
45
 
46
  return tuple(tests)
 
80
  # indices[start:end] are the indices for this expert
81
  # For each slot j, get the input index and copy the row
82
  idx = indices[start : start + num_tokens] // top_k
 
83
  out[i, :num_tokens, :] = x[idx, :]
84
  start = end
85
  return out