Kernels
Commit f9a8cd3 (verified) · committed by marcsun13 (HF Staff) · 1 parent: 10e8091

Upload folder using huggingface_hub

Files changed (27)
  1. build/torch-universal/triton_kernels/__init__.py +0 -3
  2. build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc +0 -0
  3. build/torch-universal/triton_kernels/_ops.py +2 -2
  4. build/torch-universal/triton_kernels/matmul_ogs.py +4 -4
  5. build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py +4 -4
  6. build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py +5 -5
  7. build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +4 -4
  8. build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py +1 -1
  9. build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +2 -2
  10. build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +4 -4
  11. build/torch-universal/triton_kernels/numerics_details/flexpoint.py +1 -1
  12. build/torch-universal/triton_kernels/swiglu.py +2 -2
  13. build/torch-universal/triton_kernels/swiglu_details/_swiglu.py +1 -1
  14. build/torch-universal/triton_kernels/testing.py +1 -1
  15. build/torch-universal/triton_kernels/topk.py +3 -3
  16. torch-ext/triton_kernels/matmul_ogs.py +4 -4
  17. torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py +4 -4
  18. torch-ext/triton_kernels/matmul_ogs_details/_matmul_ogs.py +5 -5
  19. torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +4 -4
  20. torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py +1 -1
  21. torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +2 -2
  22. torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +4 -4
  23. torch-ext/triton_kernels/numerics_details/flexpoint.py +1 -1
  24. torch-ext/triton_kernels/swiglu.py +2 -2
  25. torch-ext/triton_kernels/swiglu_details/_swiglu.py +1 -1
  26. torch-ext/triton_kernels/testing.py +1 -1
  27. torch-ext/triton_kernels/topk.py +3 -3
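
Every change below follows the same pattern: absolute `triton_kernels.*` imports are rewritten as package-relative imports (`from . import ...`, `from .. import ...`), and the generated ops namespace in `_ops.py` is bumped from `_triton_kernels_8830f14_dirty` to `_triton_kernels_10e8091_dirty`. The short sketch below is not part of this repository (the package name `renamed_pkg` and its files are made up); it only illustrates why the relative form matters, presumably the motivation here: a package that imports only relative to itself keeps working no matter what top-level name its directory is imported under.

import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk. "renamed_pkg" and its contents are
# hypothetical; they only mirror the relative-import style used in this commit.
root = Path(tempfile.mkdtemp())
pkg = root / "renamed_pkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "numerics.py").write_text("MAX_FINITE = 448.0\n")
# Relative import: resolves against whatever package this file lives in.
(pkg / "swiglu.py").write_text("from .numerics import MAX_FINITE\n")

sys.path.insert(0, str(root))
import renamed_pkg.swiglu

print(renamed_pkg.swiglu.MAX_FINITE)  # 448.0

# Had swiglu.py used "from triton_kernels.numerics import MAX_FINITE" instead,
# this import would fail unless a package literally named "triton_kernels"
# were importable, regardless of where the file actually lives.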
build/torch-universal/triton_kernels/__init__.py CHANGED
@@ -1,3 +0,0 @@
-from . import matmul_ogs
-
-__all__ = ["matmul_ogs"]
build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc and b/build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc differ
 
build/torch-universal/triton_kernels/_ops.py CHANGED
@@ -1,8 +1,8 @@
 import torch
-ops = torch.ops._triton_kernels_8830f14_dirty
+ops = torch.ops._triton_kernels_10e8091_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_triton_kernels_8830f14_dirty::{op_name}"
+    return f"_triton_kernels_10e8091_dirty::{op_name}"
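
For context, here is a hypothetical usage sketch, not taken from this repository, of how a namespaced op id like the one returned by `add_op_namespace_prefix` is typically consumed with `torch.library`, and how the result then appears under the `torch.ops._triton_kernels_10e8091_dirty` handle that `_ops.py` exposes. It assumes a recent PyTorch (2.4+) that provides `torch.library.custom_op`; the op `noop` is invented for illustration.

import torch

def add_op_namespace_prefix(op_name: str) -> str:
    # Mirrors the helper above, using the namespace this commit switches to.
    return f"_triton_kernels_10e8091_dirty::{op_name}"

# Register a trivial op under that namespace; "noop" is a made-up example op.
@torch.library.custom_op(add_op_namespace_prefix("noop"), mutates_args=())
def noop(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

ops = torch.ops._triton_kernels_10e8091_dirty  # same handle _ops.py binds to `ops`
print(ops.noop(torch.ones(2)))                 # dispatches through the registered op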
build/torch-universal/triton_kernels/matmul_ogs.py CHANGED
@@ -7,10 +7,10 @@ import torch
 import triton
 from enum import Enum, auto
 # utilities
-from triton_kernels import target_info
-from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
-from triton_kernels.target_info import is_cuda
+from . import target_info
+from .numerics import InFlexData, OutFlexData
+from .routing import GatherIndx, RoutingData, ScatterIndx
+from .target_info import is_cuda
 # details
 from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
 from .matmul_ogs_details._matmul_ogs import _matmul_ogs
build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py CHANGED
@@ -1,9 +1,9 @@
 import triton
 import triton.language as tl
-from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, update_scale
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from triton_kernels.target_info import cuda_capability_geq as _cuda_capability_geq
-from triton_kernels.target_info import is_hip as _is_hip
+from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
+from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ..target_info import cuda_capability_geq as _cuda_capability_geq
+from ..target_info import is_hip as _is_hip


 # fmt: off
build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py CHANGED
@@ -2,11 +2,11 @@
 # fmt: off
 import triton
 import triton.language as tl
-from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
-from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
-from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
-from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ..tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from ..tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from ..tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
+from ..numerics_details.flexpoint import float_to_flex, load_scale
+from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string


build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py CHANGED
@@ -3,15 +3,15 @@
 import torch
 import triton
 import triton.language as tl
-from triton_kernels import target_info
-from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
-from triton_kernels.numerics_details.flexpoint import (
+from . import target_info
+from ..tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from ..numerics_details.flexpoint import (
     float_to_flex,
     load_scale,
     nan_propagating_absmax_reduce,
     compute_scale,
 )
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string


build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py CHANGED
@@ -2,7 +2,7 @@
 # fmt: off
 from dataclasses import dataclass
 import triton
-from triton_kernels.target_info import get_cdna_version
+from ..target_info import get_cdna_version
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia

build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import triton
-from triton_kernels.target_info import get_cdna_version
-from triton_kernels.tensor import bitwidth
+from ...target_info import get_cdna_version
+from ...tensor import bitwidth


 def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config):
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py CHANGED
@@ -1,9 +1,9 @@
 import torch
 import triton
-from triton_kernels import target_info
-from triton_kernels.tensor import get_layout, bitwidth, FP4
-from triton_kernels.tensor_details.layout import HopperMXScaleLayout
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ... import target_info
+from ...tensor import get_layout, bitwidth, FP4
+from ...tensor_details.layout import HopperMXScaleLayout
+from ...numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE


 def compute_grid_size(routing_data, m, n, block_m, block_n):
build/torch-universal/triton_kernels/numerics_details/flexpoint.py CHANGED
@@ -1,5 +1,5 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
-from triton_kernels import target_info
+from .. import target_info
 import triton
 import triton.language as tl

build/torch-universal/triton_kernels/swiglu.py CHANGED
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
-from triton_kernels.numerics import InFlexData, OutFlexData
+from .numerics import InFlexData, OutFlexData
 import torch
 import triton
 from .swiglu_details._swiglu import _swiglu, _swiglu_fn
-from triton_kernels import target_info
+from . import target_info


 @dataclass(frozen=True)
build/torch-universal/triton_kernels/swiglu_details/_swiglu.py CHANGED
@@ -1,4 +1,4 @@
-from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale
+from ..numerics_details.flexpoint import load_scale, float_to_flex, update_scale
 import triton
 import triton.language as tl

build/torch-universal/triton_kernels/testing.py CHANGED
@@ -4,7 +4,7 @@ import os
 import subprocess
 import sys
 import torch
-from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+from .numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5


 def assert_equal(ref, tri):
build/torch-universal/triton_kernels/topk.py CHANGED
@@ -1,8 +1,8 @@
 import torch
 import triton
-from triton_kernels.topk_details._topk_forward import _topk_forward
-from triton_kernels.topk_details._topk_backward import _topk_backward
-from triton_kernels.tensor import Tensor, Bitmatrix
+from .topk_details._topk_forward import _topk_forward
+from .topk_details._topk_backward import _topk_backward
+from .tensor import Tensor, Bitmatrix


 def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
torch-ext/triton_kernels/matmul_ogs.py CHANGED
@@ -7,10 +7,10 @@ import torch
 import triton
 from enum import Enum, auto
 # utilities
-from triton_kernels import target_info
-from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
-from triton_kernels.target_info import is_cuda
+from . import target_info
+from .numerics import InFlexData, OutFlexData
+from .routing import GatherIndx, RoutingData, ScatterIndx
+from .target_info import is_cuda
 # details
 from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
 from .matmul_ogs_details._matmul_ogs import _matmul_ogs
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py CHANGED
@@ -1,9 +1,9 @@
 import triton
 import triton.language as tl
-from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, update_scale
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from triton_kernels.target_info import cuda_capability_geq as _cuda_capability_geq
-from triton_kernels.target_info import is_hip as _is_hip
+from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
+from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ..target_info import cuda_capability_geq as _cuda_capability_geq
+from ..target_info import is_hip as _is_hip


 # fmt: off
torch-ext/triton_kernels/matmul_ogs_details/_matmul_ogs.py CHANGED
@@ -2,11 +2,11 @@
 # fmt: off
 import triton
 import triton.language as tl
-from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
-from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
-from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
-from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ..tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from ..tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from ..tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
+from ..numerics_details.flexpoint import float_to_flex, load_scale
+from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string


torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py CHANGED
@@ -3,15 +3,15 @@
 import torch
 import triton
 import triton.language as tl
-from triton_kernels import target_info
-from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
-from triton_kernels.numerics_details.flexpoint import (
+from . import target_info
+from ..tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
+from ..numerics_details.flexpoint import (
     float_to_flex,
     load_scale,
     nan_propagating_absmax_reduce,
     compute_scale,
 )
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string


torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py CHANGED
@@ -2,7 +2,7 @@
 # fmt: off
 from dataclasses import dataclass
 import triton
-from triton_kernels.target_info import get_cdna_version
+from ..target_info import get_cdna_version
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia

torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py CHANGED
@@ -1,7 +1,7 @@
 import torch
 import triton
-from triton_kernels.target_info import get_cdna_version
-from triton_kernels.tensor import bitwidth
+from ...target_info import get_cdna_version
+from ...tensor import bitwidth


 def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config):
torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py CHANGED
@@ -1,9 +1,9 @@
 import torch
 import triton
-from triton_kernels import target_info
-from triton_kernels.tensor import get_layout, bitwidth, FP4
-from triton_kernels.tensor_details.layout import HopperMXScaleLayout
-from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from ... import target_info
+from ...tensor import get_layout, bitwidth, FP4
+from ...tensor_details.layout import HopperMXScaleLayout
+from ...numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE


 def compute_grid_size(routing_data, m, n, block_m, block_n):
torch-ext/triton_kernels/numerics_details/flexpoint.py CHANGED
@@ -1,5 +1,5 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
-from triton_kernels import target_info
+from .. import target_info
 import triton
 import triton.language as tl

torch-ext/triton_kernels/swiglu.py CHANGED
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
-from triton_kernels.numerics import InFlexData, OutFlexData
+from .numerics import InFlexData, OutFlexData
 import torch
 import triton
 from .swiglu_details._swiglu import _swiglu, _swiglu_fn
-from triton_kernels import target_info
+from . import target_info


 @dataclass(frozen=True)
torch-ext/triton_kernels/swiglu_details/_swiglu.py CHANGED
@@ -1,4 +1,4 @@
-from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale
+from ..numerics_details.flexpoint import load_scale, float_to_flex, update_scale
 import triton
 import triton.language as tl

torch-ext/triton_kernels/testing.py CHANGED
@@ -4,7 +4,7 @@ import os
 import subprocess
 import sys
 import torch
-from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
+from .numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5


 def assert_equal(ref, tri):
torch-ext/triton_kernels/topk.py CHANGED
@@ -1,8 +1,8 @@
 import torch
 import triton
-from triton_kernels.topk_details._topk_forward import _topk_forward
-from triton_kernels.topk_details._topk_backward import _topk_backward
-from triton_kernels.tensor import Tensor, Bitmatrix
+from .topk_details._topk_forward import _topk_forward
+from .topk_details._topk_backward import _topk_backward
+from .tensor import Tensor, Bitmatrix


 def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):