diff --git a/app.py b/app.py
index 974a442a7cab0fd17cad9a20e6e6788ba68b9e86..61f771877272c3a46f7fad71564a92a3787be668 100644
--- a/app.py
+++ b/app.py
@@ -1,15 +1,13 @@
-import os
-# import spaces
+import shlex
+import subprocess
+import spaces
 import torch
 
-os.system("nvidia-smi")
-print("TORCH_CUDA", torch.cuda.is_available())
-
-
 # install packages for mamba
 def install():
     print("Install personal packages", flush=True)
-    os.system("bash install.sh")
+    subprocess.run(shlex.split("pip install causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl"))
+    subprocess.run(shlex.split("pip install mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl"))
 
 install()
 
@@ -25,7 +23,7 @@ from videomamba_video import videomamba_tiny
 from kinetics_class_index import kinetics_classnames
 from imagenet_class_index import imagenet_classnames
 from transforms import (
-    GroupNormalize, GroupScale, GroupCenterCrop, 
+    GroupNormalize, GroupScale, GroupCenterCrop,
     Stack, ToTorchFormatTensor
 )
 
@@ -38,7 +36,7 @@ from huggingface_hub import hf_hub_download
 device = "cuda"
 model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_k400_f16_res224.pth")
 model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
-# Pick a pretrained model 
+# Pick a pretrained model
 model_video = videomamba_tiny(num_classes=400, num_frames=16)
 video_sd = torch.load(model_video_path, map_location='cpu')
 model_video.load_state_dict(video_sd)
@@ -55,7 +53,7 @@ for k, v in kinetics_classnames.items():
     kinetics_id_to_classname[k] = v
 imagenet_id_to_classname = {}
 for k, v in imagenet_classnames.items():
-    imagenet_id_to_classname[k] = v[1] 
+    imagenet_id_to_classname[k] = v[1]
 
 
 def get_index(num_frames, num_segments=8):
@@ -83,7 +81,7 @@ def load_video(video_path):
         GroupCenterCrop(crop_size),
         Stack(),
         ToTorchFormatTensor(),
-        GroupNormalize(input_mean, input_std) 
+        GroupNormalize(input_mean, input_std)
     ])
 
     images_group = list()
@@ -92,28 +90,24 @@ def load_video(video_path):
         images_group.append(img)
     torch_imgs = transform(images_group)
     return torch_imgs
-    
 
-# @spaces.GPU
+
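+# @spaces.GPU requests a GPU (ZeroGPU) allocation for the duration of each decorated call on Hugging Face Spaces.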
+@spaces.GPU
 def inference_video(video):
     vid = load_video(video)
-    
+
     # The model expects inputs of shape: B x C x T x H x W
     TC, H, W = vid.shape
     inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
-    
+
     with torch.no_grad():
         prediction = model_video(inputs.to(device))
         prediction = F.softmax(prediction, dim=1).flatten()
 
     return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
-    
-
-def set_example_video(example: list) -> dict:
-    return gr.Video.update(value=example[0])
 
 
-# @spaces.GPU
+@spaces.GPU
 def inference_image(img):
     image = img
     image_transform = T.Compose(
@@ -125,10 +119,10 @@ def inference_image(img):
     ]
     )
     image = image_transform(image)
-    
+
     # The model expects inputs of shape: B x C x H x W
     image = image.unsqueeze(0)
-    
+
     with torch.no_grad():
         prediction = model_image(image.to(device))
         prediction = F.softmax(prediction, dim=1).flatten()
@@ -136,10 +130,6 @@ def inference_image(img):
     return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
 
 
-def set_example_image(example: list) -> dict:
-    return gr.Image.update(value=example[0])
-
-
 demo = gr.Blocks()
 with demo:
     gr.Markdown(
@@ -154,26 +144,26 @@ with demo:
         with gr.Row():
             with gr.Column():
                 with gr.Row():
-                    input_video = gr.Video(label='Input Video').style(height=360)
+                    input_video = gr.Video(label='Input Video', height=360)
                 with gr.Row():
                     submit_video_button = gr.Button('Submit')
             with gr.Column():
                     label_video = gr.Label(num_top_classes=5)
         with gr.Row():
-            example_videos = gr.Dataset(components=[input_video], samples=[['./videos/hitting_baseball.mp4'], ['./videos/hoverboarding.mp4'], ['./videos/yoga.mp4']])
-        
+            gr.Examples(examples=['./videos/hitting_baseball.mp4', './videos/hoverboarding.mp4', './videos/yoga.mp4'], inputs=input_video, outputs=label_video, fn=inference_video, cache_examples=True)
+
     with gr.Tab("Image"):
         # with gr.Box():
         with gr.Row():
             with gr.Column():
                 with gr.Row():
-                    input_image = gr.Image(label='Input Image', type='pil').style(height=360)
+                    input_image = gr.Image(label='Input Image', type='pil', height=360)
                 with gr.Row():
                     submit_image_button = gr.Button('Submit')
             with gr.Column():
                 label_image = gr.Label(num_top_classes=5)
         with gr.Row():
-            example_images = gr.Dataset(components=[input_image], samples=[['./images/cat.png'], ['./images/dog.png'], ['./images/panda.png']])
+            gr.Examples(examples=['./images/cat.png', './images/dog.png', './images/panda.png'], inputs=input_image, outputs=label_image, fn=inference_image, cache_examples=True)
 
     gr.Markdown(
         """
@@ -182,9 +172,6 @@ with demo:
     )
 
     submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
-    example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos._components)
     submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
-    example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images._components)
 
-demo.launch(enable_queue=True)
-# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
\ No newline at end of file
+demo.queue(max_size=20).launch()
diff --git a/causal-conv1d/AUTHORS b/causal-conv1d/AUTHORS
deleted file mode 100644
index 88193855314bb723ced1860384e417954f559700..0000000000000000000000000000000000000000
--- a/causal-conv1d/AUTHORS
+++ /dev/null
@@ -1 +0,0 @@
-Tri Dao, tri@tridao.me
diff --git a/causal-conv1d/LICENSE b/causal-conv1d/LICENSE
deleted file mode 100644
index 5860e4b33f3d9d85fc636137c559331d51783a5b..0000000000000000000000000000000000000000
--- a/causal-conv1d/LICENSE
+++ /dev/null
@@ -1,29 +0,0 @@
-BSD 3-Clause License
-
-Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-* Redistributions of source code must retain the above copyright notice, this
-  list of conditions and the following disclaimer.
-
-* Redistributions in binary form must reproduce the above copyright notice,
-  this list of conditions and the following disclaimer in the documentation
-  and/or other materials provided with the distribution.
-
-* Neither the name of the copyright holder nor the names of its
-  contributors may be used to endorse or promote products derived from
-  this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/causal-conv1d/README.md b/causal-conv1d/README.md
deleted file mode 100644
index 4e905425a650d77c5c4854e4c4a261778c4d2690..0000000000000000000000000000000000000000
--- a/causal-conv1d/README.md
+++ /dev/null
@@ -1 +0,0 @@
-# Causal depthwise conv1d in CUDA with a PyTorch interface
diff --git a/causal-conv1d/causal_conv1d/__init__.py b/causal-conv1d/causal_conv1d/__init__.py
deleted file mode 100644
index cc4d610a1e557cabd723fb6e33438f03c5c4bf66..0000000000000000000000000000000000000000
--- a/causal-conv1d/causal_conv1d/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-__version__ = "1.0.0"
-
-from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
diff --git a/causal-conv1d/causal_conv1d/causal_conv1d_interface.py b/causal-conv1d/causal_conv1d/causal_conv1d_interface.py
deleted file mode 100644
index f66143c39e767572ca12112811a384239b8beb63..0000000000000000000000000000000000000000
--- a/causal-conv1d/causal_conv1d/causal_conv1d_interface.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright (c) 2023, Tri Dao.
-
-import torch
-import torch.nn.functional as F
-
-
-import causal_conv1d_cuda
-
-
-class CausalConv1dFn(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, weight, bias=None, activation=None):
-        if activation not in [None, "silu", "swish"]:
-            raise NotImplementedError("activation must be None, silu, or swish")
-        if x.stride(2) != 1 and x.stride(1) != 1:
-            x = x.contiguous()
-        bias = bias.contiguous() if bias is not None else None
-        ctx.save_for_backward(x, weight, bias)
-        ctx.activation = activation in ["silu", "swish"]
-        out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
-        return out
-
-    @staticmethod
-    def backward(ctx, dout):
-        x, weight, bias = ctx.saved_tensors
-        if dout.stride(2) != 1 and dout.stride(1) != 1:
-            dout = dout.contiguous()
-        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-        # backward of conv1d with the backward of chunk).
-        # Here we just pass in None and dx will be allocated in the C++ code.
-        dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, weight, bias, dout, None, ctx.activation
-        )
-        return dx, dweight, dbias if bias is not None else None, None
-
-
-def causal_conv1d_fn(x, weight, bias=None, activation=None):
-    """
-    x: (batch, dim, seqlen)
-    weight: (dim, width)
-    bias: (dim,)
-    activation: either None or "silu" or "swish"
-
-    out: (batch, dim, seqlen)
-    """
-    return CausalConv1dFn.apply(x, weight, bias, activation)
-
-
-def causal_conv1d_ref(x, weight, bias=None, activation=None):
-    """
-    x: (batch, dim, seqlen)
-    weight: (dim, width)
-    bias: (dim,)
-
-    out: (batch, dim, seqlen)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    dtype_in = x.dtype
-    x = x.to(weight.dtype)
-    seqlen = x.shape[-1]
-    dim, width = weight.shape
-    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
-    out = out[..., :seqlen]
-    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-
-
-def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
-    """
-    x: (batch, dim)
-    conv_state: (batch, dim, width)
-    weight: (dim, width)
-    bias: (dim,)
-
-    out: (batch, dim)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    activation = activation in ["silu", "swish"]
-    return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
-
-
-def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
-    """
-    x: (batch, dim)
-    conv_state: (batch, dim, width)
-    weight: (dim, width)
-    bias: (dim,)
-
-    out: (batch, dim)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    dtype_in = x.dtype
-    batch, dim = x.shape
-    width = weight.shape[1]
-    assert conv_state.shape == (batch, dim, width)
-    assert weight.shape == (dim, width)
-    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
-    conv_state[:, :, -1] = x
-    out = torch.sum(conv_state * weight, dim=-1) # (B D)
-    if bias is not None:
-        out += bias
-    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
diff --git a/causal-conv1d/csrc/causal_conv1d.cpp b/causal-conv1d/csrc/causal_conv1d.cpp
deleted file mode 100644
index 1c80516ac8599d4d80910a1d4d85c4c435cf1e4f..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/causal_conv1d.cpp
+++ /dev/null
@@ -1,333 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
-#include <vector>
-
-#include "causal_conv1d.h"
-
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-
-#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \
-    if (ITYPE == at::ScalarType::Half) {                                            \
-        using input_t = at::Half;                                                   \
-        __VA_ARGS__();                                                              \
-    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
-        using input_t = at::BFloat16;                                               \
-        __VA_ARGS__();                                                              \
-    } else if (ITYPE == at::ScalarType::Float)  {                                   \
-        using input_t = float;                                                      \
-        __VA_ARGS__();                                                              \
-    } else {                                                                        \
-        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
-    }
-
-#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)                     \
-    if (WTYPE == at::ScalarType::Half) {                                             \
-        using weight_t = at::Half;                                                   \
-        __VA_ARGS__();                                                               \
-    } else if (WTYPE == at::ScalarType::BFloat16) {                                  \
-        using weight_t = at::BFloat16;                                               \
-        __VA_ARGS__();                                                               \
-    } else if (WTYPE == at::ScalarType::Float)  {                                    \
-        using weight_t = float;                                                      \
-        __VA_ARGS__();                                                               \
-    } else {                                                                         \
-        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
-    }
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
-template <typename input_t, typename weight_t>
-void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
-template<typename input_t, typename weight_t>
-void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
-
-void set_conv_params_fwd(ConvParamsBase &params,
-                         // sizes
-                         const size_t batch,
-                         const size_t dim,
-                         const size_t seqlen,
-                         const size_t width,
-                         // device pointers
-                         const at::Tensor x,
-                         const at::Tensor weight,
-                         const at::Tensor out,
-                         void* bias_ptr,
-                         bool silu_activation) {
-
-    // Reset the parameters
-    memset(&params, 0, sizeof(params));
-
-    params.batch = batch;
-    params.dim = dim;
-    params.seqlen = seqlen;
-    params.width = width;
-
-    params.silu_activation = silu_activation;
-
-    // Set the pointers and strides.
-    params.x_ptr = x.data_ptr();
-    params.weight_ptr = weight.data_ptr();
-    params.bias_ptr = bias_ptr;
-    params.out_ptr = out.data_ptr();
-    // All strides are in elements, not bytes.
-    params.x_batch_stride = x.stride(0);
-    params.x_c_stride = x.stride(1);
-    params.x_l_stride = x.stride(-1);
-    params.weight_c_stride = weight.stride(0);
-    params.weight_width_stride = weight.stride(1);
-    params.out_batch_stride = out.stride(0);
-    params.out_c_stride = out.stride(1);
-    params.out_l_stride = out.stride(-1);
-}
-
-
-void set_conv_params_bwd(ConvParamsBwd &params,
-                         // sizes
-                         const size_t batch,
-                         const size_t dim,
-                         const size_t seqlen,
-                         const size_t width,
-                         // device pointers
-                         const at::Tensor x,
-                         const at::Tensor weight,
-                         void* bias_ptr,
-                         const at::Tensor dout,
-                         const at::Tensor dx,
-                         const at::Tensor dweight,
-                         void* dbias_ptr,
-                         bool silu_activation) {
-    // Pass in "dout" instead of "out", we're not gonna use "out" at all.
-    set_conv_params_fwd(params, batch, dim, seqlen, width,
-                        x, weight, dout, bias_ptr, silu_activation);
-
-    // Set the pointers and strides.
-    params.dout_ptr = dout.data_ptr();
-    params.dx_ptr = dx.data_ptr();
-    params.dweight_ptr = dweight.data_ptr();
-    params.dbias_ptr = dbias_ptr;
-    // All strides are in elements, not bytes.
-    params.dout_batch_stride = dout.stride(0);
-    params.dout_c_stride = dout.stride(1);
-    params.dout_l_stride = dout.stride(2);
-    params.dweight_c_stride = dweight.stride(0);
-    params.dweight_width_stride = dweight.stride(1);
-    params.dx_batch_stride = dx.stride(0);
-    params.dx_c_stride = dx.stride(1);
-    params.dx_l_stride = dx.stride(2);
-}
-
-at::Tensor
-causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
-                  const c10::optional<at::Tensor> &bias_,
-                  bool silu_activation) {
-    auto input_type = x.scalar_type();
-    auto weight_type = weight.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(weight.is_cuda());
-
-    const auto sizes = x.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int width = weight.size(-1);
-
-    CHECK_SHAPE(x, batch_size, dim, seqlen);
-    CHECK_SHAPE(weight, dim, width);
-
-    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
-    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
-
-    if (is_channel_last) {
-        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
-    }
-    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
-
-
-    if (bias_.has_value()) {
-        auto bias = bias_.value();
-        TORCH_CHECK(bias.scalar_type() == weight_type);
-        TORCH_CHECK(bias.is_cuda());
-        TORCH_CHECK(bias.stride(-1) == 1);
-        CHECK_SHAPE(bias, dim);
-    }
-
-    at::Tensor out = torch::empty_like(x);
-
-    ConvParamsBase params;
-    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
-                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
-                        silu_activation);
-
-    // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
-            if (!is_channel_last) {
-                causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
-            } else {
-                causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
-            }
-        });
-    });
-    return out;
-}
-
-std::vector<at::Tensor>
-causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
-                  const c10::optional<at::Tensor> &bias_,
-                  at::Tensor &dout,
-                  c10::optional<at::Tensor> &dx_,
-                  bool silu_activation) {
-    auto input_type = x.scalar_type();
-    auto weight_type = weight.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(weight.is_cuda());
-    TORCH_CHECK(dout.is_cuda());
-
-    const auto sizes = x.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int width = weight.size(-1);
-
-    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
-
-    CHECK_SHAPE(x, batch_size, dim, seqlen);
-    CHECK_SHAPE(weight, dim, width);
-    CHECK_SHAPE(dout, batch_size, dim, seqlen);
-
-    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
-    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
-    if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
-    if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
-
-    if (bias_.has_value()) {
-        auto bias = bias_.value();
-        TORCH_CHECK(bias.scalar_type() == weight_type);
-        TORCH_CHECK(bias.is_cuda());
-        TORCH_CHECK(bias.stride(-1) == 1);
-        CHECK_SHAPE(bias, dim);
-    }
-
-    at::Tensor dx;
-    if (dx_.has_value()) {
-        dx = dx_.value();
-        TORCH_CHECK(dx.scalar_type() == input_type);
-        TORCH_CHECK(dx.is_cuda());
-        CHECK_SHAPE(dx, batch_size, dim, seqlen);
-        if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
-        if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
-    } else {
-        dx = torch::empty_like(x);
-    }
-
-    // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
-
-    at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
-    at::Tensor dbias;
-    if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
-
-    ConvParamsBwd params;
-    set_conv_params_bwd(params, batch_size, dim, seqlen, width,
-                        x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
-                        dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
-                        silu_activation);
-
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
-            if (!is_channel_last) {
-                causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
-            } else {
-                causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
-            }
-        });
-    });
-    return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias};
-}
-
-at::Tensor
-causal_conv1d_update(const at::Tensor &x,
-                     const at::Tensor &conv_state,
-                     const at::Tensor &weight,
-                     const c10::optional<at::Tensor> &bias_,
-                  bool silu_activation) {
-    auto input_type = x.scalar_type();
-    auto weight_type = weight.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(conv_state.scalar_type() == input_type);
-
-    TORCH_CHECK(x.is_cuda());
-    TORCH_CHECK(conv_state.is_cuda());
-    TORCH_CHECK(weight.is_cuda());
-
-    const auto sizes = x.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int width = weight.size(-1);
-
-    CHECK_SHAPE(x, batch_size, dim);
-    CHECK_SHAPE(conv_state, batch_size, dim, width);
-    CHECK_SHAPE(weight, dim, width);
-
-    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
-
-    if (bias_.has_value()) {
-        auto bias = bias_.value();
-        TORCH_CHECK(bias.scalar_type() == weight_type);
-        TORCH_CHECK(bias.is_cuda());
-        TORCH_CHECK(bias.stride(-1) == 1);
-        CHECK_SHAPE(bias, dim);
-    }
-
-    at::Tensor out = torch::empty_like(x);
-
-    ConvParamsBase params;
-    set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
-                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
-                        silu_activation);
-    params.conv_state_ptr = conv_state.data_ptr();
-    // All strides are in elements, not bytes.
-    params.conv_state_batch_stride = conv_state.stride(0);
-    params.conv_state_c_stride = conv_state.stride(1);
-    params.conv_state_l_stride = conv_state.stride(2);
-
-    // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
-            causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
-        });
-    });
-    return out;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
-    m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
-    m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
-}
diff --git a/causal-conv1d/csrc/causal_conv1d.h b/causal-conv1d/csrc/causal_conv1d.h
deleted file mode 100644
index 844ed92cfc91a881e58fccfca001a13ebcc434cc..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/causal_conv1d.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-struct ConvParamsBase {
-    using index_t = uint32_t;
-
-    int batch, dim, seqlen, width;
-    bool silu_activation;
-
-    index_t x_batch_stride;
-    index_t x_c_stride;
-    index_t x_l_stride;
-    index_t weight_c_stride;
-    index_t weight_width_stride;
-    index_t out_batch_stride;
-    index_t out_c_stride;
-    index_t out_l_stride;
-
-    index_t conv_state_batch_stride;
-    index_t conv_state_c_stride;
-    index_t conv_state_l_stride;
-
-    // Common data pointers.
-    void *__restrict__ x_ptr;
-    void *__restrict__ weight_ptr;
-    void *__restrict__ bias_ptr;
-    void *__restrict__ out_ptr;
-
-    void *__restrict__ conv_state_ptr;
-};
-
-struct ConvParamsBwd: public ConvParamsBase {
-    index_t dx_batch_stride;
-    index_t dx_c_stride;
-    index_t dx_l_stride;
-    index_t dweight_c_stride;
-    index_t dweight_width_stride;
-    index_t dout_batch_stride;
-    index_t dout_c_stride;
-    index_t dout_l_stride;
-
-    // Common data pointers.
-    void *__restrict__ dx_ptr;
-    void *__restrict__ dweight_ptr;
-    void *__restrict__ dbias_ptr;
-    void *__restrict__ dout_ptr;
-};
-
diff --git a/causal-conv1d/csrc/causal_conv1d_bwd.cu b/causal-conv1d/csrc/causal_conv1d_bwd.cu
deleted file mode 100644
index 66609750a30a86a284451871ca163d79a0529047..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/causal_conv1d_bwd.cu
+++ /dev/null
@@ -1,525 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#include <c10/util/BFloat16.h>
-#include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
-
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-#include <cub/block/block_reduce.cuh>
-
-#include "causal_conv1d.h"
-#include "causal_conv1d_common.h"
-#include "static_switch.h"
-
-template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_bwd_kernel_traits {
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    static constexpr int kWidth = kWidth_;
-    static constexpr bool kSiluAct = kSiluAct_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
-    static_assert(kWidth <= kNElts);
-    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
-    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
-    static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
-    static constexpr bool kIsVecLoad = kIsVecLoad_;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
-    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
-    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
-    static constexpr int kSmemIOSize = kIsVecLoad
-        ? 0
-        : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
-    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
-    static constexpr int kSmemSize = std::max({kSmemExchangeSize,
-            int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr bool kSiluAct = Ktraits::kSiluAct;
-    constexpr int kNElts = Ktraits::kNElts;
-    constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
-    constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
-    using input_t = typename Ktraits::input_t;
-    using vec_t = typename Ktraits::vec_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
-    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
-    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
-    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
-    vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
-    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
-
-    const int tidx = threadIdx.x;
-    const int batch_id = blockIdx.x;
-    const int dim_id = blockIdx.y;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + dim_id * params.x_c_stride;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
-    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
-        + dim_id * params.dout_c_stride;
-    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
-        + dim_id * params.dx_c_stride;
-    float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
-    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
-
-    // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
-    if (tidx == 0) {
-        if constexpr (!kSiluAct) {
-            input_t zeros[kNElts] = {0};
-            smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
-        } else {
-            float zeros[kNElts] = {0};
-            #pragma unroll
-            for (int r = 0; r < kNExchangeRounds; ++r) {
-                smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
-            }
-        }
-    }
-
-    float weight_vals[kWidth];
-    #pragma unroll
-    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
-
-    float dweight_vals[kWidth] = {0};
-    float dbias_val = 0;
-
-    constexpr int kChunkSize = kNThreads * kNElts;
-    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
-    x += (n_chunks - 1) * kChunkSize;
-    dout += (n_chunks - 1) * kChunkSize;
-    dx += (n_chunks - 1) * kChunkSize;
-    for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
-        input_t x_vals_load[2 * kNElts] = {0};
-        input_t dout_vals_load[2 * kNElts] = {0};
-        if constexpr(kIsVecLoad) {
-            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
-            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
-        } else {
-            __syncthreads();
-            Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
-            __syncthreads();
-            Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
-        }
-        float dout_vals[2 * kNElts], x_vals[2 * kNElts];
-        if constexpr (!kSiluAct) {
-            __syncthreads();
-            // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
-            // the first elements of the next chunk.
-            if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
-            __syncthreads();
-            reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
-            __syncthreads();
-            // Now thread 0 can write the first elements of the current chunk.
-            if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
-            #pragma unroll
-            for (int i = 0; i < 2 * kNElts; ++i) {
-                dout_vals[i] = float(dout_vals_load[i]);
-                x_vals[i] = float(x_vals_load[i]);
-            }
-        } else {
-            if (tidx == 0 && chunk > 0) {
-                if constexpr(kIsVecLoad) {
-                    reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
-                } else {
-                    #pragma unroll
-                    for (int i = 0; i < kNElts; ++i) {
-                        if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
-                    }
-                }
-            }
-            __syncthreads();
-            smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
-            __syncthreads();
-            if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
-            #pragma unroll
-            for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
-            // Recompute the output
-            #pragma unroll
-            for (int i = 0; i < kNElts; ++i) {
-                float out_val = bias_val;
-                #pragma unroll
-                for (int w = 0; w < kWidth; ++w) {
-                    out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
-                }
-                float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
-                dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
-                               * (1.0f + out_val * (1.0f - out_sigmoid_val));
-            }
-            // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
-            // if input_t is 16 bits (since then we'd have 8 values of float)
-            __syncthreads();
-            // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
-            // the first elements of the next chunk.
-            if (tidx > 0) {
-                #pragma unroll
-                for (int r = 0; r < kNExchangeRounds; ++r) {
-                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
-                }
-            }
-            __syncthreads();
-            #pragma unroll
-            for (int r = 0; r < kNExchangeRounds; ++r) {
-                reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
-                    = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
-            }
-            __syncthreads();
-            // Now thread 0 can write the first elements of the current chunk.
-            if (tidx == 0) {
-                #pragma unroll
-                for (int r = 0; r < kNExchangeRounds; ++r) {
-                    smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
-                }
-            }
-        }
-        dout -= kChunkSize;
-        x -= kChunkSize;
-
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
-
-        float dx_vals[kNElts] = {0};
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) {
-            #pragma unroll
-            for (int w = 0; w < kWidth; ++w) {
-                dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
-            }
-        }
-
-        input_t dx_vals_store[kNElts];
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
-        if constexpr(kIsVecLoad) {
-            Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
-        } else {
-            Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
-        }
-        dx -= kChunkSize;
-
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) {
-            #pragma unroll
-            for (int i = 0; i < kNElts; ++i) {
-                dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
-            }
-        }
-    }
-
-    #pragma unroll
-    for (int w = 0; w < kWidth; ++w) {
-        __syncthreads();
-        dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
-        if (tidx == 0) {
-            atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
-        }
-    }
-    if (params.bias_ptr != nullptr) {
-        __syncthreads();
-        dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
-        if (tidx == 0) {
-            atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
-        }
-    }
-}
-
-template<int kNThreads, int kWidth, typename input_t, typename weight_t>
-void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
-    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
-    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
-        BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
-            using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
-            constexpr int kSmemSize = Ktraits::kSmemSize;
-            dim3 grid(params.batch, params.dim);
-            auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
-            if (kSmemSize >= 48 * 1024) {
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                }
-            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-        });
-    });
-}
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
-    if (params.width == 2) {
-        causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
-    } else if (params.width == 3) {
-        causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
-    } else if (params.width == 4) {
-        causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
-    }
-}
-
-template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_channellast_bwd_kernel_traits {
-    // The cache line is 128 bytes, and we try to read 16 bytes per thread.
-    // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
-    // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
-    // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr bool kSiluAct = kSiluAct_;
-    static constexpr int kNThreads = kNThreads_;
-    static_assert(kNThreads % 32 == 0);
-    static constexpr int kNWarps = kNThreads / 32;
-    static constexpr int kWidth = kWidth_;
-    static constexpr int kChunkSizeL = kChunkSizeL_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
-    static constexpr int kNEltsPerRow = 128 / kNBytes;
-    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
-    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
-    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
-    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
-    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
-    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
-    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
-    static constexpr bool kIsVecLoad = kIsVecLoad_;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
-    //                                            sizeof(typename BlockStoreT::TempStorage)});
-    // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr bool kSiluAct = Ktraits::kSiluAct;
-    constexpr int kNElts = Ktraits::kNElts;
-    constexpr int kNWarp = Ktraits::kNWarps;
-    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
-    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
-    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
-    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
-    using input_t = typename Ktraits::input_t;
-    using vec_t = typename Ktraits::vec_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    // Shared memory.
-    __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
-    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
-
-    const int tid = threadIdx.x;
-    const int l_idx = tid / kNThreadsPerC;
-    const int c_idx = tid % kNThreadsPerC;
-    const int batch_id = blockIdx.x;
-    const int chunk_l_id = blockIdx.y;
-    const int chunk_c_id = blockIdx.z;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
-        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
-    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
-        + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
-
-    #pragma unroll
-    for (int l = 0; l < Ktraits::kNLoads; ++l) {
-        input_t dout_vals_load[kNElts] = {0};
-        input_t x_vals_load[kNElts] = {0};
-        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
-        }
-        reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
-        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-    }
-    // Load the elements from the previous chunk or next chunk that are needed for convolution.
-    if (l_idx < kWidth - 1) {
-        input_t dout_vals_load[kNElts] = {0};
-        input_t x_vals_load[kNElts] = {0};
-        if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
-        }
-        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
-            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
-        }
-        reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
-        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-    }
-    // Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
-    if constexpr (kSiluAct) {
-        if (l_idx < kWidth - 1) {
-            input_t x_vals_load[kNElts] = {0};
-            if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
-                && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-                reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
-            }
-            reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-        }
-    }
-
-    __syncthreads();
-
-    constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
-    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
-    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
-    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
-    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
-    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
-    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
-    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
-    static_assert(kNThreadsPerRow <= 32);
-
-    const int row_idx = tid / kNThreadsPerRow;
-    const int col_idx = tid % kNThreadsPerRow;
-
-    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
-    float weight_vals[kWidth] = {0};
-    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) {
-            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
-        }
-    }
-    float dout_vals[kLPerThread + kWidth - 1];
-    float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
-    #pragma unroll
-    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
-        dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
-        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
-    }
-
-    if constexpr (kSiluAct) {  // Recompute the output
-        #pragma unroll
-        for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
-            x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
-        }
-        #pragma unroll
-        for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
-            float out_val = bias_val;
-            #pragma unroll
-            for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; }
-            float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
-            dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
-        }
-    }
-
-    float dweight_vals[kWidth] = {0};
-    SumOp<float> sum_op;
-    #pragma unroll
-    for (int w = 0; w < kWidth; ++w) {
-        #pragma unroll
-        for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; }
-        dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
-        if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
-            atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
-        }
-    }
-
-    if (params.bias_ptr != nullptr) {
-        float dbias_val = 0.f;
-        for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
-        dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
-        if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
-            atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
-        }
-    }
-
-    float dx_vals[kLPerThread] = {0};
-    #pragma unroll
-    for (int i = 0; i < kLPerThread; ++i) {
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; }
-    }
-    // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
-    __syncwarp();
-    #pragma unroll
-    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
-    __syncthreads();
-
-    #pragma unroll
-    for (int l = 0; l < Ktraits::kNLoads; ++l) {
-        input_t dx_vals_store[kNElts];
-        reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
-        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
-        }
-    }
-
-}
-
-template<int kNThreads, int kWidth, typename input_t, typename weight_t>
-void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
-        using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, 64, kSiluAct, true, input_t, weight_t>;
-        // constexpr int kSmemSize = Ktraits::kSmemSize;
-        constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
-        constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
-        const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
-        const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
-        dim3 grid(params.batch, n_chunks_L, n_chunks_C);
-        dim3 block(Ktraits::kNThreads);
-        auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits>;
-        // if (kSmemSize >= 48 * 1024) {
-        //     C10_CUDA_CHECK(cudaFuncSetAttribute(
-        //         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-        //     }
-        // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-        kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-    });
-}
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
-    if (params.width == 2) {
-        causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
-    } else if (params.width == 3) {
-        causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
-    } else if (params.width == 4) {
-        causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
-    }
-}
-
-template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
-
-template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
-template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
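
The `dx` loop in the removed backward kernel above (`dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]`) is the adjoint of the causal depthwise convolution: a cross-correlation of `dout` with the weights read in reverse. A minimal PyTorch sketch of that relation, checked against autograd (illustrative only, not part of the repo; `W`, `L`, and `dx_manual` are made-up names):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, W = 16, 4
x = torch.randn(L, requires_grad=True)
weight = torch.randn(W)

# Causal depthwise conv for a single channel:
# out[l] = sum_w weight[w] * x[l - (W - 1 - w)], with zeros for negative indices.
out = F.conv1d(F.pad(x, (W - 1, 0)).view(1, 1, -1), weight.view(1, 1, -1)).view(-1)

dout = torch.randn(L)
out.backward(dout)

# Same recurrence as the kernel: dx[i] = sum_w weight[W - 1 - w] * dout[i + w],
# with dout treated as zero past the end of the sequence.
dout_pad = F.pad(dout, (0, W - 1))
dx_manual = torch.stack(
    [sum(weight[W - 1 - w] * dout_pad[i + w] for w in range(W)) for i in range(L)]
)

assert torch.allclose(x.grad, dx_manual, atol=1e-5)
```
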
diff --git a/causal-conv1d/csrc/causal_conv1d_common.h b/causal-conv1d/csrc/causal_conv1d_common.h
deleted file mode 100644
index 8dd6a333b52163986c085f71475709706ce8f9c3..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/causal_conv1d_common.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int BYTES> struct BytesToType {};
-
-template<> struct BytesToType<16> {
-    using Type = uint4;
-    static_assert(sizeof(Type) == 16);
-};
-
-template<> struct BytesToType<8> {
-    using Type = uint64_t;
-    static_assert(sizeof(Type) == 8);
-};
-
-template<> struct BytesToType<4> {
-    using Type = uint32_t;
-    static_assert(sizeof(Type) == 4);
-};
-
-template<> struct BytesToType<2> {
-    using Type = uint16_t;
-    static_assert(sizeof(Type) == 2);
-};
-
-template<> struct BytesToType<1> {
-    using Type = uint8_t;
-    static_assert(sizeof(Type) == 1);
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T>
-struct SumOp {
-__device__ inline T operator()(T const & x, T const & y) { return x + y; }
-};
-
-template<int THREADS>
-struct Allreduce {
-    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
-    template<typename T, typename Operator>
-    static __device__ inline T run(T x, Operator &op) {
-        constexpr int OFFSET = THREADS / 2;
-        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
-        return Allreduce<OFFSET>::run(x, op);
-    }
-};
-
-template<>
-struct Allreduce<2> {
-template<typename T, typename Operator>
-static __device__ inline T run(T x, Operator &op) {
-    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
-    return x;
-}
-};
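
`Allreduce` above is a butterfly (XOR-shuffle) reduction: in each round every lane combines its value with the lane whose index differs in one bit, halving the offset until every lane holds the full sum. A pure-Python sketch of the same idea (illustrative only; `allreduce_xor` is a made-up name):

```python
def allreduce_xor(vals):
    """Butterfly all-reduce over a power-of-two group, like Allreduce<THREADS>::run."""
    n = len(vals)
    assert n > 0 and n & (n - 1) == 0
    offset = n // 2
    while offset >= 1:
        # Every "lane" i adds the value held by lane i ^ offset (cf. __shfl_xor_sync).
        vals = [vals[i] + vals[i ^ offset] for i in range(n)]
        offset //= 2
    return vals

print(allreduce_xor(list(range(8))))  # [28, 28, 28, 28, 28, 28, 28, 28]
```
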
diff --git a/causal-conv1d/csrc/causal_conv1d_fwd.cu b/causal-conv1d/csrc/causal_conv1d_fwd.cu
deleted file mode 100644
index 74a1459f88a87ef427075a25e5081899e382efc0..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/causal_conv1d_fwd.cu
+++ /dev/null
@@ -1,350 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#include <c10/util/BFloat16.h>
-#include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
-
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-
-#include "causal_conv1d.h"
-#include "causal_conv1d_common.h"
-#include "static_switch.h"
-
-template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_fwd_kernel_traits {
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    static constexpr int kWidth = kWidth_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
-    static_assert(kWidth <= kNElts);
-    static constexpr bool kIsVecLoad = kIsVecLoad_;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
-    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
-    static constexpr int kSmemIOSize = kIsVecLoad
-        ? 0
-        : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
-    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
-    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_fwd_kernel(ConvParamsBase params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr int kNElts = Ktraits::kNElts;
-    constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
-    using input_t = typename Ktraits::input_t;
-    using vec_t = typename Ktraits::vec_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
-    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
-    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
-    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
-
-    const int tidx = threadIdx.x;
-    const int batch_id = blockIdx.x;
-    const int channel_id = blockIdx.y;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + channel_id * params.x_c_stride;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
-    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
-        + channel_id * params.out_c_stride;
-    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
-
-    // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
-    if (tidx == 0) {
-        input_t zeros[kNElts] = {0};
-        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
-    }
-
-    float weight_vals[kWidth];
-    #pragma unroll
-    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
-
-    constexpr int kChunkSize = kNThreads * kNElts;
-    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
-    for (int chunk = 0; chunk < n_chunks; ++chunk) {
-        input_t x_vals_load[2 * kNElts] = {0};
-        if constexpr(kIsVecLoad) {
-            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
-        } else {
-            __syncthreads();
-            Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
-        }
-        x += kChunkSize;
-        __syncthreads();
-        // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
-        // the last elements of the previous chunk.
-        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
-        __syncthreads();
-        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
-        __syncthreads();
-        // Now thread kNThreads - 1 can write the last elements of the current chunk.
-        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
-
-        float x_vals[2 * kNElts];
-        #pragma unroll
-        for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
-
-        float out_vals[kNElts];
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) {
-            out_vals[i] = bias_val;
-            #pragma unroll
-            for (int w = 0; w < kWidth; ++w) {
-                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
-            }
-        }
-
-        if (params.silu_activation) {
-            #pragma unroll
-            for (int i = 0; i < kNElts; ++i) {
-                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
-            }
-        }
-
-        input_t out_vals_store[kNElts];
-        #pragma unroll
-        for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
-        if constexpr(kIsVecLoad) {
-            Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
-        } else {
-            Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
-        }
-        out += kChunkSize;
-    }
-}
-
-template<int kNThreads, int kWidth, typename input_t, typename weight_t>
-void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
-    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
-    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
-        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
-        constexpr int kSmemSize = Ktraits::kSmemSize;
-        dim3 grid(params.batch, params.dim);
-        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
-        if (kSmemSize >= 48 * 1024) {
-            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-            }
-        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-    });
-}
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
-    if (params.width == 2) {
-        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
-    } else if (params.width == 3) {
-        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
-    } else if (params.width == 4) {
-        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
-    }
-}
-
-template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_channellast_fwd_kernel_traits {
-    // The cache line is 128 bytes, and we try to read 16 bytes per thread.
-    // So we have 8 threads per "row", i.e. 32 or 64 elements in the channel dimension.
-    // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
-    // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    static_assert(kNThreads % 32 == 0);
-    static constexpr int kNWarps = kNThreads / 32;
-    static constexpr int kWidth = kWidth_;
-    static constexpr int kChunkSizeL = kChunkSizeL_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
-    static constexpr int kNEltsPerRow = 128 / kNBytes;
-    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
-    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
-    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
-    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
-    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
-    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
-    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
-    static constexpr bool kIsVecLoad = kIsVecLoad_;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
-    //                                            sizeof(typename BlockStoreT::TempStorage)});
-    // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr int kNElts = Ktraits::kNElts;
-    constexpr int kNWarp = Ktraits::kNWarps;
-    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
-    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
-    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
-    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
-    using input_t = typename Ktraits::input_t;
-    using vec_t = typename Ktraits::vec_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    // Shared memory.
-    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
-
-    const int tid = threadIdx.x;
-    const int l_idx = tid / kNThreadsPerC;
-    const int c_idx = tid % kNThreadsPerC;
-    const int batch_id = blockIdx.x;
-    const int chunk_l_id = blockIdx.y;
-    const int chunk_c_id = blockIdx.z;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
-        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
-    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
-        + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
-
-    #pragma unroll
-    for (int l = 0; l < Ktraits::kNLoads; ++l) {
-        input_t x_vals_load[kNElts] = {0};
-        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
-        }
-        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-    }
-    // Load the elements from the previous chunk that are needed for convolution.
-    if (l_idx < kWidth - 1) {
-        input_t x_vals_load[kNElts] = {0};
-        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
-            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
-        }
-        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
-    }
-
-    __syncthreads();
-
-    constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
-    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
-    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
-    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
-    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
-    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
-    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
-    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
-    static_assert(kNThreadsPerRow <= 32);
-
-    const int row_idx = tid / kNThreadsPerRow;
-    const int col_idx = tid % kNThreadsPerRow;
-
-    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
-    float weight_vals[kWidth] = {0};
-    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) {
-            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
-        }
-    }
-    float x_vals[kWidth - 1 + kLPerThread];
-    #pragma unroll
-    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
-        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
-    }
-
-    float out_vals[kLPerThread];
-    #pragma unroll
-    for (int i = 0; i < kLPerThread; ++i) {
-        out_vals[i] = bias_val;
-        #pragma unroll
-        for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; }
-        if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
-    }
-
-    // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads.
-    __syncwarp();
-    #pragma unroll
-    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
-    __syncthreads();
-
-    #pragma unroll
-    for (int l = 0; l < Ktraits::kNLoads; ++l) {
-        input_t out_vals_store[kNElts];
-        reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
-        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
-            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
-            *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
-        }
-    }
-
-}
-
-template<int kNThreads, int kWidth, typename input_t, typename weight_t>
-void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
-    using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
-    // constexpr int kSmemSize = Ktraits::kSmemSize;
-    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
-    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
-    const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
-    const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
-    // printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C);
-    dim3 grid(params.batch, n_chunks_L, n_chunks_C);
-    dim3 block(Ktraits::kNThreads);
-    auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>;
-    // if (kSmemSize >= 48 * 1024) {
-    //     C10_CUDA_CHECK(cudaFuncSetAttribute(
-    //         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-    //     }
-    // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
-    if (params.width == 2) {
-        causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
-    } else if (params.width == 3) {
-        causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
-    } else if (params.width == 4) {
-        causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
-    }
-}
-
-template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-
-template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
\ No newline at end of file
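
The removed forward kernels compute a causal depthwise convolution chunk by chunk, with `smem_exchange` (contiguous layout) or the shared-memory halo rows (channel-last layout) carrying the last `kWidth - 1` inputs of the previous chunk. A small NumPy sketch of that chunked scheme, checked against the direct formula (illustrative only; `causal_conv1d_chunked`, `chunk`, and `carry` are made-up names):

```python
import numpy as np

def causal_conv1d_chunked(x, weight, bias=0.0, silu=False, chunk=8):
    W = len(weight)
    carry = np.zeros(W - 1)              # last W-1 inputs of the previous chunk
    out = np.empty_like(x)
    for start in range(0, len(x), chunk):
        xc = x[start:start + chunk]
        xext = np.concatenate([carry, xc])
        for i in range(len(xc)):
            acc = bias
            for w in range(W):
                acc += weight[w] * xext[i + w]   # == weight[w] * x[start+i - (W-1-w)]
            out[start + i] = acc / (1.0 + np.exp(-acc)) if silu else acc
        carry = xext[len(xc):]           # tail handed to the next chunk
    return out

rng = np.random.default_rng(0)
x, w = rng.standard_normal(37), rng.standard_normal(4)
ref = np.array([
    sum(w[j] * (x[i - (len(w) - 1 - j)] if i - (len(w) - 1 - j) >= 0 else 0.0)
        for j in range(len(w)))
    for i in range(len(x))
])
assert np.allclose(causal_conv1d_chunked(x, w), ref)
```
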
diff --git a/causal-conv1d/csrc/causal_conv1d_update.cu b/causal-conv1d/csrc/causal_conv1d_update.cu
deleted file mode 100644
index 713e0ac883853491f9bdb0015b578657c228c1e7..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/causal_conv1d_update.cu
+++ /dev/null
@@ -1,96 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#include <c10/util/BFloat16.h>
-#include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
-
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-
-#include "causal_conv1d.h"
-#include "causal_conv1d_common.h"
-#include "static_switch.h"
-
-template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
-struct Causal_conv1d_update_kernel_traits {
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    static constexpr int kWidth = kWidth_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads)
-void causal_conv1d_update_kernel(ConvParamsBase params) {
-    constexpr int kWidth = Ktraits::kWidth;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    using input_t = typename Ktraits::input_t;
-    using weight_t = typename Ktraits::weight_t;
-
-    const int tidx = threadIdx.x;
-    const int batch_id = blockIdx.x;
-    const int channel_id = blockIdx.y * kNThreads + tidx;
-    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
-        + channel_id * params.x_c_stride;
-    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
-        + channel_id * params.conv_state_c_stride;
-    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
-    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
-        + channel_id * params.out_c_stride;
-    float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
-
-    float weight_vals[kWidth] = {0};
-    if (channel_id < params.dim) {
-        #pragma unroll
-        for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
-    }
-
-    float x_vals[kWidth] = {0};
-    if (channel_id < params.dim) {
-        #pragma unroll
-        for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
-        x_vals[kWidth - 1] = float(x[0]);
-        #pragma unroll
-        for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
-    }
-
-    float out_val = bias_val;
-    #pragma unroll
-    for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
-    if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
-    if (channel_id < params.dim) { out[0] = input_t(out_val); }
-}
-
-template<int kNThreads, int kWidth, typename input_t, typename weight_t>
-void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
-    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
-    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
-    auto kernel = &causal_conv1d_update_kernel<Ktraits>;
-    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-template<typename input_t, typename weight_t>
-void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
-    if (params.width == 2) {
-        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
-    } else if (params.width == 3) {
-        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
-    } else if (params.width == 4) {
-        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
-    }
-}
-
-template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
-template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
\ No newline at end of file
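
The removed update kernel is the single-token decode step: per channel it shifts `conv_state` left by one position, appends the new input, and emits the dot product with the weights, plus bias and optional SiLU. A NumPy sketch of the same step (illustrative only; `causal_conv1d_update_step` is a made-up name):

```python
import numpy as np

def causal_conv1d_update_step(x_t, conv_state, weight, bias=None, silu=False):
    """x_t: (dim,), conv_state: (dim, width), updated in place; weight: (dim, width)."""
    # Drop the oldest sample, append the newest (cf. the conv_state writes above).
    conv_state[:] = np.concatenate([conv_state[:, 1:], x_t[:, None]], axis=1)
    out = (conv_state * weight).sum(axis=-1)
    if bias is not None:
        out = out + bias
    return out / (1.0 + np.exp(-out)) if silu else out

dim, width = 4, 3
rng = np.random.default_rng(0)
state = np.zeros((dim, width))
w = rng.standard_normal((dim, width))
for x_t in rng.standard_normal((10, dim)):
    y_t = causal_conv1d_update_step(x_t, state, w, silu=True)  # y_t: (dim,)
```
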
diff --git a/causal-conv1d/csrc/static_switch.h b/causal-conv1d/csrc/static_switch.h
deleted file mode 100644
index 0f4ad3eb62235443d15c454b6691c2ec63645219..0000000000000000000000000000000000000000
--- a/causal-conv1d/csrc/static_switch.h
+++ /dev/null
@@ -1,25 +0,0 @@
-// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
-// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
-
-#pragma once
-
-/// @param COND       - a boolean expression to switch by
-/// @param CONST_NAME - a name given for the constexpr bool variable.
-/// @param ...       - code to execute for true and false
-///
-/// Usage:
-/// ```
-/// BOOL_SWITCH(flag, BoolConst, [&] {
-///     some_function<BoolConst>(...);
-/// });
-/// ```
-#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
-    [&] {                                                                            \
-        if (COND) {                                                                  \
-            static constexpr bool CONST_NAME = true;                                 \
-            return __VA_ARGS__();                                                    \
-        } else {                                                                     \
-            static constexpr bool CONST_NAME = false;                                \
-            return __VA_ARGS__();                                                    \
-        }                                                                            \
-    }()
diff --git a/causal-conv1d/setup.py b/causal-conv1d/setup.py
deleted file mode 100644
index 12e36bf988215a4c536278026e6f4401e66534da..0000000000000000000000000000000000000000
--- a/causal-conv1d/setup.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) 2023, Tri Dao.
-import sys
-import warnings
-import os
-import re
-import ast
-from pathlib import Path
-from packaging.version import parse, Version
-import platform
-
-from setuptools import setup, find_packages
-import subprocess
-
-import urllib.request
-import urllib.error
-from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
-
-import torch
-from torch.utils.cpp_extension import (
-    BuildExtension,
-    CppExtension,
-    CUDAExtension,
-    CUDA_HOME,
-)
-
-
-with open("README.md", "r", encoding="utf-8") as fh:
-    long_description = fh.read()
-
-
-# ninja build does not work unless include_dirs are abs path
-this_dir = os.path.dirname(os.path.abspath(__file__))
-
-PACKAGE_NAME = "causal_conv1d"
-
-BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
-
-# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
-# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
-FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
-SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
-# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
-FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
-
-
-def get_platform():
-    """
-    Returns the platform name as used in wheel filenames.
-    """
-    if sys.platform.startswith("linux"):
-        return "linux_x86_64"
-    elif sys.platform == "darwin":
-        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
-        return f"macosx_{mac_version}_x86_64"
-    elif sys.platform == "win32":
-        return "win_amd64"
-    else:
-        raise ValueError("Unsupported platform: {}".format(sys.platform))
-
-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output(
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
-    )
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    bare_metal_version = parse(output[release_idx].split(",")[0])
-
-    return raw_output, bare_metal_version
-
-
-def check_if_cuda_home_none(global_option: str) -> None:
-    if CUDA_HOME is not None:
-        return
-    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
-    # in that case.
-    warnings.warn(
-        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
-        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
-        "only images whose names contain 'devel' will provide nvcc."
-    )
-
-
-def append_nvcc_threads(nvcc_extra_args):
-    return nvcc_extra_args + ["--threads", "4"]
-
-
-cmdclass = {}
-ext_modules = []
-
-if not SKIP_CUDA_BUILD:
-    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-
-    check_if_cuda_home_none("causal_conv1d")
-    # Check that CUDA 11.6+ is available and pick the target compute capabilities
-    cc_flag = []
-    if CUDA_HOME is not None:
-        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-        if bare_metal_version < Version("11.6"):
-            raise RuntimeError(
-                "causal_conv1d is only supported on CUDA 11.6 and above.  "
-                "Note: make sure nvcc has a supported version by running nvcc -V."
-            )
-
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_70,code=sm_70")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_80,code=sm_80")
-    if CUDA_HOME is not None and bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
-
-    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
-    # torch._C._GLIBCXX_USE_CXX11_ABI
-    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
-    if FORCE_CXX11_ABI:
-        torch._C._GLIBCXX_USE_CXX11_ABI = True
-
-    ext_modules.append(
-        CUDAExtension(
-            name="causal_conv1d_cuda",
-            sources=[
-                "csrc/causal_conv1d.cpp",
-                "csrc/causal_conv1d_fwd.cu",
-                "csrc/causal_conv1d_bwd.cu",
-                "csrc/causal_conv1d_update.cu",
-            ],
-            extra_compile_args={
-                "cxx": ["-O3"],
-                "nvcc": append_nvcc_threads(
-                    [
-                        "-O3",
-                        "-U__CUDA_NO_HALF_OPERATORS__",
-                        "-U__CUDA_NO_HALF_CONVERSIONS__",
-                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                        "--expt-relaxed-constexpr",
-                        "--expt-extended-lambda",
-                        "--use_fast_math",
-                        "--ptxas-options=-v",
-                        "-lineinfo",
-                    ]
-                    + cc_flag
-                ),
-            },
-            include_dirs=[this_dir],
-        )
-    )
-
-
-def get_package_version():
-    with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
-        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
-    public_version = ast.literal_eval(version_match.group(1))
-    local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
-    if local_version:
-        return f"{public_version}+{local_version}"
-    else:
-        return str(public_version)
-
-
-def get_wheel_url():
-    # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
-    torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    causal_conv1d_version = get_package_version()
-    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    wheel_url = BASE_WHEEL_URL.format(
-        tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
-    )
-    return wheel_url, wheel_filename
-
-
-class CachedWheelsCommand(_bdist_wheel):
-    """
-    The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it
-    cannot find an existing wheel (which is currently the case for all installs). We use the
-    environment parameters to detect whether a pre-built compatible wheel is already available and,
-    if so, short-circuit the standard full build pipeline.
-    """
-
-    def run(self):
-        if FORCE_BUILD:
-            return super().run()
-
-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()
-
-
-setup(
-    name=PACKAGE_NAME,
-    version=get_package_version(),
-    packages=find_packages(
-        exclude=(
-            "build",
-            "csrc",
-            "include",
-            "tests",
-            "dist",
-            "docs",
-            "benchmarks",
-            "causal_conv1d.egg-info",
-        )
-    ),
-    author="Tri Dao",
-    author_email="tri@tridao.me",
-    description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/Dao-AILab/causal-conv1d",
-    classifiers=[
-        "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: BSD License",
-        "Operating System :: Unix",
-    ],
-    ext_modules=ext_modules,
-    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
-    if ext_modules
-    else {
-        "bdist_wheel": CachedWheelsCommand,
-    },
-    python_requires=">=3.7",
-    install_requires=[
-        "torch",
-        "packaging",
-        "ninja",
-    ],
-)
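
For reference, this is roughly what `get_wheel_url()` above produces, with hard-coded example values in place of probing torch and the interpreter (the CUDA/torch/ABI values below are illustrative, not a claim about this Space's environment):

```python
PACKAGE_NAME = "causal_conv1d"
BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"

# Example inputs; the real code derives these from torch.version.cuda, torch.__version__, etc.
version, cuda, torch_ver, abi, py, plat = "1.0.0", "118", "2.1", "FALSE", "cp310", "linux_x86_64"

wheel_name = f"{PACKAGE_NAME}-{version}+cu{cuda}torch{torch_ver}cxx11abi{abi}-{py}-{py}-{plat}.whl"
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{version}", wheel_name=wheel_name)
print(wheel_url)
# https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.0.0/causal_conv1d-1.0.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```
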
diff --git a/causal-conv1d/tests/test_causal_conv1d.py b/causal-conv1d/tests/test_causal_conv1d.py
deleted file mode 100644
index 6e5985cfb0582e6656afb1d8b5c1de78f24f4276..0000000000000000000000000000000000000000
--- a/causal-conv1d/tests/test_causal_conv1d.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright (C) 2023, Tri Dao.
-
-import math
-
-import torch
-import pytest
-
-from einops import rearrange
-
-from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
-from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
-
-
-@pytest.mark.parametrize("channel_last", [False, True])
-# @pytest.mark.parametrize('channel_last', [True])
-@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize('itype', [torch.float16])
-@pytest.mark.parametrize("silu_activation", [False, True])
-# @pytest.mark.parametrize('silu_activation', [True])
-@pytest.mark.parametrize("has_bias", [False, True])
-# @pytest.mark.parametrize('has_bias', [True])
-@pytest.mark.parametrize("width", [2, 3, 4])
-# @pytest.mark.parametrize('width', [2])
-@pytest.mark.parametrize(
-    "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
-)
-# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
-# @pytest.mark.parametrize('seqlen', [128])
-def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
-    device = "cuda"
-    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
-    if itype == torch.bfloat16:
-        rtol, atol = 1e-2, 5e-2
-    rtolw, atolw = (1e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    # batch_size = 1
-    dim = 4096 + 32  # Try dim not divisible by 64
-    # dim = 64
-    if not channel_last:
-        x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
-    else:
-        x = rearrange(
-            torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
-        ).requires_grad_()
-    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
-    if has_bias:
-        bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    else:
-        bias = None
-    x_ref = x.detach().clone().requires_grad_()
-    weight_ref = weight.detach().clone().requires_grad_()
-    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
-    activation = None if not silu_activation else "silu"
-    out = causal_conv1d_fn(x, weight, bias, activation=activation)
-    out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
-
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-
-    g = torch.randn_like(out)
-    out_ref.backward(g)
-    out.backward(g)
-
-    print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
-    print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
-    if has_bias:
-        print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
-
-    assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
-    assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
-    if has_bias:
-        assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
-
-
-@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize('itype', [torch.float16])
-@pytest.mark.parametrize("silu_activation", [False, True])
-# @pytest.mark.parametrize('silu_activation', [False])
-@pytest.mark.parametrize("has_bias", [False, True])
-# @pytest.mark.parametrize('has_bias', [True])
-@pytest.mark.parametrize("width", [2, 3, 4])
-# @pytest.mark.parametrize('width', [2])
-@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
-# @pytest.mark.parametrize("dim", [2048])
-def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
-    device = "cuda"
-    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
-    if itype == torch.bfloat16:
-        rtol, atol = 1e-2, 5e-2
-    rtolw, atolw = (1e-3, 1e-3)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    # batch_size = 1
-    # dim = 64
-    x = torch.randn(batch_size, dim, device=device, dtype=itype)
-    conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
-    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
-    if has_bias:
-        bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    else:
-        bias = None
-    conv_state_ref = conv_state.detach().clone()
-    activation = None if not silu_activation else "silu"
-    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
-    out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
-
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    assert torch.equal(conv_state, conv_state_ref)
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-
-
-# @pytest.mark.parametrize("channel_last", [False, True])
-@pytest.mark.parametrize('channel_last', [True])
-# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('itype', [torch.bfloat16])
-# @pytest.mark.parametrize("silu_activation", [False, True])
-@pytest.mark.parametrize('silu_activation', [True])
-# @pytest.mark.parametrize("has_bias", [False, True])
-@pytest.mark.parametrize('has_bias', [True])
-# @pytest.mark.parametrize("width", [2, 3, 4])
-@pytest.mark.parametrize('width', [4])
-@pytest.mark.parametrize(
-    # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
-    "seqlen", [2048]
-)
-# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
-# @pytest.mark.parametrize('seqlen', [128])
-def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
-    device = "cuda"
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    # batch_size = 1
-    dim = 4096 + 32  # Try dim not divisible by 64
-    # dim = 64
-    if not channel_last:
-        x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
-    else:
-        x = rearrange(
-            torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
-        ).requires_grad_()
-    weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
-    if has_bias:
-        bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    else:
-        bias = None
-    activation = None if not silu_activation else "silu"
-    out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
-    g = torch.randn_like(out0)
-    dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
-    dw_atol = 1e-4
-    db_atol = 1e-4
-
-    for i in range(10000):
-        out = causal_conv1d_fn(x, weight, bias, activation=activation)
-        dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
-        dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
-        # if not dw_equal:
-        #     breakpoint()
-        if has_bias:
-            db_equal = torch.allclose(db, db0, atol=db_atol)
-            # if not db_equal:
-            #     breakpoint()
-        assert torch.equal(out, out0)
-        assert torch.equal(dx, dx0)
-        assert dw_equal
-        if has_bias:
-            assert db_equal
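
The removed tests compare the CUDA path against `causal_conv1d_ref` / `causal_conv1d_update_ref`, whose sources are not part of this diff. A minimal sketch of the forward reference the kernels are checked against (an assumption based on the kernel math above, not the repo's actual implementation; `causal_depthwise_conv1d` is a made-up name):

```python
import torch
import torch.nn.functional as F

def causal_depthwise_conv1d(x, weight, bias=None, activation=None):
    """x: (batch, dim, seqlen), weight: (dim, width); causal depthwise conv."""
    seqlen, width = x.shape[-1], weight.shape[-1]
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=x.shape[1])
    out = out[..., :seqlen]              # keep only the causal part
    return F.silu(out) if activation == "silu" else out

x = torch.randn(2, 64, 128)
w = torch.randn(64, 4)
b = torch.randn(64)
y = causal_depthwise_conv1d(x, w, b, activation="silu")
assert y.shape == (2, 64, 128)
```
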
diff --git a/causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl b/causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl
new file mode 100644
index 0000000000000000000000000000000000000000..3c346180501b1f8dc77018794c5676256e5043a3
--- /dev/null
+++ b/causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78328bff9f0cf4814aa3c4029d63aa75128694e07ddae688b16215e3d8a2e7e7
+size 8424758
diff --git a/install.sh b/install.sh
deleted file mode 100644
index a0473cd3342036d5a523de17668cdc1d39bd4521..0000000000000000000000000000000000000000
--- a/install.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118
-pip install -e causal-conv1d
-pip install -e mamba
\ No newline at end of file
diff --git a/mamba/.gitmodules b/mamba/.gitmodules
deleted file mode 100644
index a7445800fb64f3ae664c0b994a54235105986d2e..0000000000000000000000000000000000000000
--- a/mamba/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "3rdparty/lm-evaluation-harness"]
-	path = 3rdparty/lm-evaluation-harness
-	url = https://github.com/EleutherAI/lm-evaluation-harness/
diff --git a/mamba/AUTHORS b/mamba/AUTHORS
deleted file mode 100644
index 38557a872f8d603ed963a05c211de7032de5926b..0000000000000000000000000000000000000000
--- a/mamba/AUTHORS
+++ /dev/null
@@ -1,2 +0,0 @@
-Tri Dao, tri@tridao.me
-Albert Gu, agu@andrew.cmu.edu
diff --git a/mamba/LICENSE b/mamba/LICENSE
deleted file mode 100644
index f4abe24eb520fbb077753ae4f34bfaa43cb3b83f..0000000000000000000000000000000000000000
--- a/mamba/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright 2023 Tri Dao, Albert Gu
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
diff --git a/mamba/README.md b/mamba/README.md
deleted file mode 100644
index 754cefd7f862a90bad8fbdff71e3793a4e7849e3..0000000000000000000000000000000000000000
--- a/mamba/README.md
+++ /dev/null
@@ -1,149 +0,0 @@
-# Mamba
-
-![Mamba](assets/selection.png "Selective State Space")
-> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
-> Albert Gu*, Tri Dao*\
-> Paper: https://arxiv.org/abs/2312.00752
-
-## About
-
-Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
-It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
-with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
-
-## Installation
-
-- `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
-- `pip install mamba-ssm`: the core Mamba package.
-
-It can also be built from source with `pip install .` from this repository.
-
-If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
-
-Other requirements:
-- Linux
-- NVIDIA GPU
-- PyTorch 1.12+
-- CUDA 11.6+
-
-## Usage
-
-We expose several levels of interface with the Mamba model.
-
-### Selective SSM
-
-Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
-
-Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
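-
-As a minimal sketch of calling the fused scan directly (assuming the `selective_scan_fn` wrapper exported by that file, with the `(u, delta, A, B, C, ...)` argument order used by the C++ binding; shapes follow the checks in `csrc/selective_scan/selective_scan.cpp`, and the random tensors are purely illustrative):
-
-```
-import torch
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-
-batch, dim, seqlen, dstate = 2, 64, 128, 16
-u = torch.randn(batch, dim, seqlen, device="cuda")        # input sequence, shape (B, D, L)
-delta = torch.rand(batch, dim, seqlen, device="cuda")     # per-step discretization
-A = -torch.rand(dim, dstate, device="cuda")               # state matrix (D, N), negative real part
-B = torch.randn(batch, 1, dstate, seqlen, device="cuda")  # input-dependent B, (B, groups, N, L)
-C = torch.randn(batch, 1, dstate, seqlen, device="cuda")  # input-dependent C, (B, groups, N, L)
-D = torch.ones(dim, device="cuda")                        # skip connection, shape (D,)
-
-y = selective_scan_fn(u, delta, A, B, C, D=D, delta_softplus=True)
-print(y.shape)  # (batch, dim, seqlen)
-```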
-
-### Mamba Block
-
-The main module of this repository is the Mamba architecture block wrapping the selective SSM.
-
-Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
-
-Usage:
-```
-import torch
-from mamba_ssm import Mamba
-
-batch, length, dim = 2, 64, 16
-x = torch.randn(batch, length, dim).to("cuda")
-model = Mamba(
-    # This module uses roughly 3 * expand * d_model^2 parameters
-    d_model=dim, # Model dimension d_model
-    d_state=16,  # SSM state expansion factor
-    d_conv=4,    # Local convolution width
-    expand=2,    # Block expansion factor
-).to("cuda")
-y = model(x)
-assert y.shape == x.shape
-```
-
-### Mamba Language Model
-
-Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
-
-Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
-
-This is an example of how to integrate Mamba into an end-to-end neural network.
-This example is used in the generation scripts below.
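-
-As a minimal sketch of end-to-end generation (mirroring the benchmark script below, and assuming the hosted `EleutherAI/gpt-neox-20b` tokenizer, i.e. the GPT-NeoX tokenizer the pretrained checkpoints were trained with):
-
-```
-import torch
-from transformers import AutoTokenizer
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-
-device, dtype = "cuda", torch.float16
-tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
-model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=dtype)
-
-input_ids = tokenizer("Mamba is a state space model that", return_tensors="pt").input_ids.to(device)
-out = model.generate(
-    input_ids=input_ids,
-    max_length=input_ids.shape[1] + 50,
-    cg=True,                        # cached-generation path, as in the benchmark script
-    return_dict_in_generate=True,
-    output_scores=True,
-    temperature=1.0,
-    top_k=1,
-    top_p=1.0,
-)
-print(tokenizer.batch_decode(out.sequences.tolist())[0])
-```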
-
-
-
-## Pretrained Models
-
-Pretrained models are uploaded to
-[HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
-`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.
-
-The models will be autodownloaded by the generation script below.
-
-These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
-
-| Parameters | Layers | Model dim. | 
-|------------|--------|------------|
-| 130M       | 12     | 768        |
-| 370M       | 24     | 1024       |
-| 790M       | 24     | 1536       |
-| 1.4B       | 24     | 2048       |
-| 2.8B       | 32     | 2560       |
-
-(The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
-
-Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
-Performance is expected to be comparable to or better than that of other architectures trained on similar data, but not to match larger or fine-tuned models.
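-
-As a rough sanity check on the table above, the per-block estimate quoted in the usage snippet earlier (about 3 * expand * d_model^2 parameters, with two Mamba blocks per Transformer "layer") plus a tied embedding roughly reproduces the nominal sizes; the vocabulary size below is an assumption (padded GPT-NeoX vocabulary):
-
-```
-# Back-of-the-envelope parameter count for mamba-130m (all figures approximate).
-d_model, n_layer, expand, vocab = 768, 12, 2, 50280  # vocab is an assumed padded GPT-NeoX vocabulary
-blocks = 2 * n_layer                                 # two Mamba blocks per Transformer "layer"
-block_params = 3 * expand * d_model ** 2             # per-block estimate from the usage snippet above
-embedding = vocab * d_model                          # input embedding, tied with the LM head
-print(f"~{(blocks * block_params + embedding) / 1e6:.0f}M parameters")  # ~124M vs. the nominal 130M
-```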
-
-
-## Evaluations
-
-To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
-we use the
-[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
-library.
-
-1. Pull the `lm-evaluation-harness` repo with `git submodule update --init
-   --recursive`. We use the `big-refactor` branch.
-2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
-3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
-```
-python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
-python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
-```
-
-Note that the result of each task might differ from the reported values by 0.1-0.3 points due to noise in the evaluation process.
-
-## Inference
-
-The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
-1. autoloads a model from the HuggingFace Hub,
-2. generates completions of a user-specified prompt,
-3. benchmarks the inference speed of this generation.
-
-Other configurable options include the top-p (nucleus sampling) probability and the softmax temperature.
-
-### Examples
-
-To test generation latency (e.g. batch size = 1) with different sampling strategies:
-
-```
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
-```
-
-To test generation throughput with random prompts (e.g. large batch size):
-```
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
-```
-
-## Citation
-
-If you use this codebase, or otherwise find our work valuable, please cite Mamba:
-```
-@article{mamba,
-  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
-  author={Gu, Albert and Dao, Tri},
-  journal={arXiv preprint arXiv:2312.00752},
-  year={2023}
-}
-```
diff --git a/mamba/assets/selection.png b/mamba/assets/selection.png
deleted file mode 100644
index 69b109a8eed4e3c7516b23e2b39d37e842a4464b..0000000000000000000000000000000000000000
Binary files a/mamba/assets/selection.png and /dev/null differ
diff --git a/mamba/benchmarks/benchmark_generation_mamba_simple.py b/mamba/benchmarks/benchmark_generation_mamba_simple.py
deleted file mode 100644
index 8f2943cb4bde6f25eddb82b7b999c5c5f8b39acc..0000000000000000000000000000000000000000
--- a/mamba/benchmarks/benchmark_generation_mamba_simple.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (c) 2023, Tri Dao, Albert Gu.
-
-import argparse
-import time
-import json
-
-import torch
-import torch.nn.functional as F
-
-from einops import rearrange
-
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-
-
-parser = argparse.ArgumentParser(description="Generation benchmarking")
-parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
-parser.add_argument("--prompt", type=str, default=None)
-parser.add_argument("--promptlen", type=int, default=100)
-parser.add_argument("--genlen", type=int, default=100)
-parser.add_argument("--temperature", type=float, default=1.0)
-parser.add_argument("--topk", type=int, default=1)
-parser.add_argument("--topp", type=float, default=1.0)
-parser.add_argument("--batch", type=int, default=1)
-args = parser.parse_args()
-
-repeats = 3
-device = "cuda"
-dtype = torch.float16
-
-print(f"Loading model {args.model_name}")
-is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name
-
-if is_mamba:
-    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # GPT-NeoX tokenizer from the Hub rather than a machine-local path
-    model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
-else:
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
-model.eval()
-print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
-
-torch.random.manual_seed(0)
-if args.prompt is None:
-    input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
-    attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
-else:
-    tokens = tokenizer(args.prompt, return_tensors="pt")
-    input_ids = tokens.input_ids.to(device=device)
-    attn_mask = tokens.attention_mask.to(device=device)
-max_length = input_ids.shape[1] + args.genlen
-
-if is_mamba:
-    fn = lambda: model.generate(
-        input_ids=input_ids,
-        max_length=max_length,
-        cg=True,
-        return_dict_in_generate=True,
-        output_scores=True,
-        enable_timing=False,
-        temperature=args.temperature,
-        top_k=args.topk,
-        top_p=args.topp,
-    )
-else:
-    fn = lambda: model.generate(
-        input_ids=input_ids,
-        attention_mask=attn_mask,
-        max_length=max_length,
-        return_dict_in_generate=True,
-        pad_token_id=tokenizer.eos_token_id,
-        do_sample=True,
-        temperature=args.temperature,
-        top_k=args.topk,
-        top_p=args.topp,
-    )
-out = fn()
-if args.prompt is not None:
-    print(tokenizer.batch_decode(out.sequences.tolist()))
-
-torch.cuda.synchronize()
-start = time.time()
-for _ in range(repeats):
-    fn()
-torch.cuda.synchronize()
-print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
-print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
diff --git a/mamba/csrc/selective_scan/reverse_scan.cuh b/mamba/csrc/selective_scan/reverse_scan.cuh
deleted file mode 100644
index d7e93174bb391d45271e6c77669a5e52d6c9cc78..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/reverse_scan.cuh
+++ /dev/null
@@ -1,401 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <cub/config.cuh>
-
-#include <cub/util_ptx.cuh>
-#include <cub/util_type.cuh>
-#include <cub/block/block_raking_layout.cuh>
-// #include <cub/detail/uninitialized_copy.cuh>
-#include "uninitialized_copy.cuh"
-
-/**
- * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array.  The aggregate is returned.
- */
-template <
-    int         LENGTH,
-    typename    T,
-    typename    ReductionOp>
-__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
-    static_assert(LENGTH > 0);
-    T retval = input[LENGTH - 1];
-    #pragma unroll
-    for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
-    return retval;
-}
-
-/**
- * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix.  The aggregate is returned.
- */
-template <
-    int         LENGTH,
-    typename    T,
-    typename    ScanOp>
-__device__ __forceinline__ T ThreadReverseScanInclusive(
-    const T (&input)[LENGTH],
-    T (&output)[LENGTH],
-    ScanOp scan_op,
-    const T postfix)
-{
-    T inclusive = postfix;
-    #pragma unroll
-    for (int i = LENGTH - 1; i >= 0; --i) {
-        inclusive = scan_op(inclusive, input[i]);
-        output[i] = inclusive;
-    }
-    return inclusive;
-}
-
-/**
- * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix.  The aggregate is returned.
- */
-template <
-    int         LENGTH,
-    typename    T,
-    typename    ScanOp>
-__device__ __forceinline__ T ThreadReverseScanExclusive(
-    const T (&input)[LENGTH],
-    T (&output)[LENGTH],
-    ScanOp scan_op,
-    const T postfix)
-{
-    // Careful, output may be aliased to input
-    T exclusive = postfix;
-    T inclusive;
-    #pragma unroll
-    for (int i = LENGTH - 1; i >= 0; --i) {
-        inclusive = scan_op(exclusive, input[i]);
-        output[i] = exclusive;
-        exclusive = inclusive;
-    }
-    return inclusive;
-}
-
-
-/**
- * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
- *
- * LOGICAL_WARP_THREADS must be a power-of-two
- */
-template <
-    typename    T,                      ///< Data type being scanned
-    int         LOGICAL_WARP_THREADS    ///< Number of threads per logical warp
-    >
-struct WarpReverseScan {
-    //---------------------------------------------------------------------
-    // Constants and type definitions
-    //---------------------------------------------------------------------
-
-    /// Whether the logical warp size and the PTX warp size coincide
-    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
-    /// The number of warp scan steps
-    static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
-    static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
-
-
-    //---------------------------------------------------------------------
-    // Thread fields
-    //---------------------------------------------------------------------
-
-    /// Lane index in logical warp
-    unsigned int lane_id;
-
-    /// Logical warp index in 32-thread physical warp
-    unsigned int warp_id;
-
-    /// 32-thread physical warp member mask of logical warp
-    unsigned int member_mask;
-
-    //---------------------------------------------------------------------
-    // Construction
-    //---------------------------------------------------------------------
-
-    /// Constructor
-    explicit __device__ __forceinline__
-    WarpReverseScan()
-        : lane_id(cub::LaneId())
-        , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
-        , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
-    {
-        if (!IS_ARCH_WARP) {
-            lane_id = lane_id % LOGICAL_WARP_THREADS;
-        }
-    }
-
-
-    /// Broadcast
-    __device__ __forceinline__ T Broadcast(
-        T               input,              ///< [in] The value to broadcast
-        int             src_lane)           ///< [in] Which warp lane is to do the broadcasting
-    {
-        return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
-    }
-
-
-    /// Inclusive scan
-    template <typename ScanOpT>
-    __device__ __forceinline__ void InclusiveReverseScan(
-        T               input,              ///< [in] Calling thread's input item.
-        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
-        ScanOpT         scan_op)            ///< [in] Binary scan operator
-    {
-        inclusive_output = input;
-        #pragma unroll
-        for (int STEP = 0; STEP < STEPS; STEP++) {
-            int offset = 1 << STEP;
-            T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
-                inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
-            );
-            // Perform scan op if from a valid peer
-            inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
-                ? inclusive_output : scan_op(temp, inclusive_output);
-        }
-    }
-
-    /// Exclusive scan
-    // Get exclusive from inclusive
-    template <typename ScanOpT>
-    __device__ __forceinline__ void ExclusiveReverseScan(
-        T              input,              ///< [in] Calling thread's input item.
-        T              &exclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
-        ScanOpT        scan_op,            ///< [in] Binary scan operator
-        T              &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
-    {
-        T inclusive_output;
-        InclusiveReverseScan(input, inclusive_output, scan_op);
-        warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
-        // initial value unknown
-        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
-            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
-        );
-    }
-
-    /**
-     * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp.  Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
-     */
-    template <typename ScanOpT>
-    __device__ __forceinline__ void ReverseScan(
-        T               input,              ///< [in] Calling thread's input item.
-        T               &inclusive_output,  ///< [out] Calling thread's inclusive-scan output item.
-        T               &exclusive_output,  ///< [out] Calling thread's exclusive-scan output item.
-        ScanOpT         scan_op)            ///< [in] Binary scan operator
-    {
-        InclusiveReverseScan(input, inclusive_output, scan_op);
-        // initial value unknown
-        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
-            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
-        );
-    }
-
-};
-
-/**
- * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
- */
-template <
-    typename    T,              ///< Data type being scanned
-    int         BLOCK_DIM_X,    ///< The thread block length in threads along the X dimension
-    bool        MEMOIZE=false   ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
-    >
-struct BlockReverseScan {
-    //---------------------------------------------------------------------
-    // Types and constants
-    //---------------------------------------------------------------------
-
-    /// Constants
-    /// The thread block size in threads
-    static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
-
-    /// Layout type for padded thread block raking grid
-    using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
-    // For now, require the number of reduction elements to be a multiple of the number of raking threads (no guarded raking)
-    static_assert(BlockRakingLayout::UNGUARDED);
-
-    /// Number of raking threads
-    static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
-    /// Number of raking elements per warp synchronous raking thread
-    static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
-    /// Cooperative work can be entirely warp synchronous
-    static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
-
-    ///  WarpReverseScan utility type
-    using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
-
-    /// Shared memory storage layout type
-    struct _TempStorage {
-        typename BlockRakingLayout::TempStorage raking_grid;     ///< Padded thread block raking grid
-    };
-
-
-    /// Alias wrapper allowing storage to be unioned
-    struct TempStorage : cub::Uninitialized<_TempStorage> {};
-
-
-    //---------------------------------------------------------------------
-    // Per-thread fields
-    //---------------------------------------------------------------------
-
-    // Thread fields
-    _TempStorage    &temp_storage;
-    unsigned int    linear_tid;
-    T               cached_segment[SEGMENT_LENGTH];
-
-
-    //---------------------------------------------------------------------
-    // Utility methods
-    //---------------------------------------------------------------------
-
-    /// Performs upsweep raking reduction, returning the aggregate
-    template <typename ScanOp>
-    __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
-        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
-        // Read data into registers
-        #pragma unroll
-        for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
-        T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
-        #pragma unroll
-        for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
-            raking_partial = scan_op(raking_partial, cached_segment[i]);
-        }
-        return raking_partial;
-    }
-
-
-    /// Performs exclusive downsweep raking scan
-    template <typename ScanOp>
-    __device__ __forceinline__ void ExclusiveDownsweep(
-        ScanOp          scan_op,
-        T               raking_partial)
-    {
-        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
-        // Read data back into registers
-        if (!MEMOIZE) {
-            #pragma unroll
-            for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
-        }
-        ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
-        // Write data back to smem
-        #pragma unroll
-        for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
-    }
-
-
-    //---------------------------------------------------------------------
-    // Constructors
-    //---------------------------------------------------------------------
-
-    /// Constructor
-    __device__ __forceinline__ BlockReverseScan(
-        TempStorage &temp_storage)
-    :
-        temp_storage(temp_storage.Alias()),
-        linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
-    {}
-
-
-    /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor.  Each thread contributes one input element.  The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs.  Also provides every thread with the block-wide \p block_aggregate of all inputs.
-    template <
-        typename ScanOp,
-        typename BlockPostfixCallbackOp>
-    __device__ __forceinline__ void ExclusiveReverseScan(
-        T                       input,                          ///< [in] Calling thread's input item
-        T                       &exclusive_output,              ///< [out] Calling thread's output item (may be aliased to \p input)
-        ScanOp                  scan_op,                        ///< [in] Binary scan operator
-        BlockPostfixCallbackOp  &block_postfix_callback_op)     ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
-    {
-        if (WARP_SYNCHRONOUS) {
-            // Short-circuit directly to warp-synchronous scan
-            T block_aggregate;
-            WarpReverseScan warp_scan;
-            warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
-            // Obtain warp-wide postfix in lane0, then broadcast to other lanes
-            T block_postfix = block_postfix_callback_op(block_aggregate);
-            block_postfix = warp_scan.Broadcast(block_postfix, 0);
-            exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
-        } else {
-            // Place thread partial into shared memory raking grid
-            T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
-            detail::uninitialized_copy(placement_ptr, input);
-            cub::CTA_SYNC();
-            // Reduce parallelism down to just raking threads
-            if (linear_tid < RAKING_THREADS) {
-                WarpReverseScan warp_scan;
-                // Raking upsweep reduction across shared partials
-                T upsweep_partial = Upsweep(scan_op);
-                // Warp-synchronous scan
-                T exclusive_partial, block_aggregate;
-                warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
-                // Obtain block-wide postfix in lane0, then broadcast to other lanes
-                T block_postfix = block_postfix_callback_op(block_aggregate);
-                block_postfix = warp_scan.Broadcast(block_postfix, 0);
-                // Update postfix with warpscan exclusive partial
-                T downsweep_postfix = linear_tid == RAKING_THREADS - 1
-                    ? block_postfix : scan_op(block_postfix, exclusive_partial);
-                // Exclusive raking downsweep scan
-                ExclusiveDownsweep(scan_op, downsweep_postfix);
-            }
-            cub::CTA_SYNC();
-            // Grab thread postfix from shared memory
-            exclusive_output = *placement_ptr;
-
-            // // Compute warp scan in each warp.
-            // // The exclusive output from the last lane in each warp is invalid.
-            // T inclusive_output;
-            // WarpReverseScan warp_scan;
-            // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
-
-            // // Compute the warp-wide postfix and block-wide aggregate for each warp.  Warp postfix for the last warp is invalid.
-            // T block_aggregate;
-            // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
-
-            // // Apply warp postfix to our lane's partial
-            // if (warp_id != 0) {
-            //     exclusive_output = scan_op(warp_postfix, exclusive_output);
-            //     if (lane_id == 0) { exclusive_output = warp_postfix; }
-            // }
-
-            // // Use the first warp to determine the thread block postfix, returning the result in lane0
-            // if (warp_id == 0) {
-            //     T block_postfix = block_postfix_callback_op(block_aggregate);
-            //     if (lane_id == 0) {
-            //         // Share the postfix with all threads
-            //         detail::uninitialized_copy(&temp_storage.block_postfix,
-            //                                   block_postfix);
-
-            //         exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
-            //     }
-            // }
-
-            // cub::CTA_SYNC();
-
-            // // Incorporate thread block postfix into outputs
-            // T block_postfix = temp_storage.block_postfix;
-            // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
-        }
-    }
-
-
-    /**
-     * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor.  Each thread contributes an array of consecutive input elements.  The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs.  Also provides every thread with the block-wide \p block_aggregate of all inputs.
-     */
-    template <
-        int             ITEMS_PER_THREAD,
-        typename        ScanOp,
-        typename        BlockPostfixCallbackOp>
-    __device__ __forceinline__ void InclusiveReverseScan(
-        T                       (&input)[ITEMS_PER_THREAD],     ///< [in] Calling thread's input items
-        T                       (&output)[ITEMS_PER_THREAD],    ///< [out] Calling thread's output items (may be aliased to \p input)
-        ScanOp                  scan_op,                        ///< [in] Binary scan functor
-        BlockPostfixCallbackOp   &block_postfix_callback_op)    ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
-    {
-        // Reduce consecutive thread items in registers
-        T thread_postfix = ThreadReverseReduce(input, scan_op);
-        // Exclusive thread block-scan
-        ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
-        // Inclusive scan in registers with postfix as seed
-        ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
-    }
-
-};
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan.cpp b/mamba/csrc/selective_scan/selective_scan.cpp
deleted file mode 100644
index f51af402a190dc14247ef8185a7d01b697313f02..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan.cpp
+++ /dev/null
@@ -1,497 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
-#include <vector>
-
-#include "selective_scan.h"
-
-#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-
-#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \
-    if (ITYPE == at::ScalarType::Half) {                                            \
-        using input_t = at::Half;                                                   \
-        __VA_ARGS__();                                                              \
-    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
-        using input_t = at::BFloat16;                                               \
-        __VA_ARGS__();                                                              \
-    } else if (ITYPE == at::ScalarType::Float)  {                                   \
-        using input_t = float;                                                      \
-        __VA_ARGS__();                                                              \
-    } else {                                                                        \
-        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
-    }
-
-#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)                     \
-    if (WTYPE == at::ScalarType::Half) {                                             \
-        using weight_t = at::Half;                                                   \
-        __VA_ARGS__();                                                               \
-    } else if (WTYPE == at::ScalarType::BFloat16) {                                  \
-        using weight_t = at::BFloat16;                                               \
-        __VA_ARGS__();                                                               \
-    } else if (WTYPE == at::ScalarType::Float)  {                                    \
-        using weight_t = float;                                                      \
-        __VA_ARGS__();                                                               \
-    } else {                                                                         \
-        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
-    }
-
-#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...)                           \
-    if (WTYPE == at::ScalarType::Float) {                                            \
-       using weight_t = float;                                                       \
-        __VA_ARGS__();                                                               \
-    } else if (WTYPE == at::ScalarType::ComplexFloat) {                              \
-        using weight_t = c10::complex<float>;                                        \
-        __VA_ARGS__();                                                               \
-    } else {                                                                         \
-        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
-    }
-
-template<typename input_t, typename weight_t>
-void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
-
-template <typename input_t, typename weight_t>
-void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
-
-void set_ssm_params_fwd(SSMParamsBase &params,
-                        // sizes
-                        const size_t batch,
-                        const size_t dim,
-                        const size_t seqlen,
-                        const size_t dstate,
-                        const size_t n_groups,
-                        const size_t n_chunks,
-                        const bool is_variable_B,
-                        const bool is_variable_C,
-                        // device pointers
-                        const at::Tensor u,
-                        const at::Tensor delta,
-                        const at::Tensor A,
-                        const at::Tensor B,
-                        const at::Tensor C,
-                        const at::Tensor out,
-                        const at::Tensor z,
-                        const at::Tensor out_z,
-                        void* D_ptr,
-                        void* delta_bias_ptr,
-                        void* x_ptr,
-                        bool has_z,
-                        bool delta_softplus) {
-
-    // Reset the parameters
-    memset(&params, 0, sizeof(params));
-
-    params.batch = batch;
-    params.dim = dim;
-    params.seqlen = seqlen;
-    params.dstate = dstate;
-    params.n_groups = n_groups;
-    params.n_chunks = n_chunks;
-    params.dim_ngroups_ratio = dim / n_groups;
-
-    params.delta_softplus = delta_softplus;
-
-    params.is_variable_B = is_variable_B;
-    params.is_variable_C = is_variable_C;
-
-    // Set the pointers and strides.
-    params.u_ptr = u.data_ptr();
-    params.delta_ptr = delta.data_ptr();
-    params.A_ptr = A.data_ptr();
-    params.B_ptr = B.data_ptr();
-    params.C_ptr = C.data_ptr();
-    params.D_ptr = D_ptr;
-    params.delta_bias_ptr = delta_bias_ptr;
-    params.out_ptr = out.data_ptr();
-    params.x_ptr = x_ptr;
-    params.z_ptr = has_z ? z.data_ptr() : nullptr;
-    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
-    // All strides are in elements, not bytes.
-    params.A_d_stride = A.stride(0);
-    params.A_dstate_stride = A.stride(1);
-    if (!is_variable_B) {
-        params.B_d_stride = B.stride(0);
-    } else {
-        params.B_batch_stride = B.stride(0);
-        params.B_group_stride = B.stride(1);
-    }
-    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
-    if (!is_variable_C) {
-        params.C_d_stride = C.stride(0);
-    } else {
-        params.C_batch_stride = C.stride(0);
-        params.C_group_stride = C.stride(1);
-    }
-    params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
-    params.u_batch_stride = u.stride(0);
-    params.u_d_stride = u.stride(1);
-    params.delta_batch_stride = delta.stride(0);
-    params.delta_d_stride = delta.stride(1);
-    if (has_z) {
-        params.z_batch_stride = z.stride(0);
-        params.z_d_stride = z.stride(1);
-        params.out_z_batch_stride = out_z.stride(0);
-        params.out_z_d_stride = out_z.stride(1);
-    }
-    params.out_batch_stride = out.stride(0);
-    params.out_d_stride = out.stride(1);
-}
-
-void set_ssm_params_bwd(SSMParamsBwd &params,
-                        // sizes
-                        const size_t batch,
-                        const size_t dim,
-                        const size_t seqlen,
-                        const size_t dstate,
-                        const size_t n_groups,
-                        const size_t n_chunks,
-                        const bool is_variable_B,
-                        const bool is_variable_C,
-                        // device pointers
-                        const at::Tensor u,
-                        const at::Tensor delta,
-                        const at::Tensor A,
-                        const at::Tensor B,
-                        const at::Tensor C,
-                        const at::Tensor z,
-                        const at::Tensor out,
-                        const at::Tensor out_z,
-                        void* D_ptr,
-                        void* delta_bias_ptr,
-                        void* x_ptr,
-                        const at::Tensor dout,
-                        const at::Tensor du,
-                        const at::Tensor ddelta,
-                        const at::Tensor dA,
-                        const at::Tensor dB,
-                        const at::Tensor dC,
-                        const at::Tensor dz,
-                        void* dD_ptr,
-                        void* ddelta_bias_ptr,
-                        bool has_z,
-                        bool delta_softplus,
-                        bool recompute_out_z) {
-    // Pass in "dout" instead of "out"; we're not going to use "out" unless we have z
-    set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
-                       u, delta, A, B, C, has_z ? out : dout,
-                       has_z ? z : dout,
-                       // If not recompute_out_z, pass dout instead of out_z.
-                       // This won't be used by the bwd kernel
-                       recompute_out_z ? out_z : dout,
-                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
-    if (!recompute_out_z) { params.out_z_ptr = nullptr; }
-
-    // Set the pointers and strides.
-    params.dout_ptr = dout.data_ptr();
-    params.du_ptr = du.data_ptr();
-    params.dA_ptr = dA.data_ptr();
-    params.dB_ptr = dB.data_ptr();
-    params.dC_ptr = dC.data_ptr();
-    params.dD_ptr = dD_ptr;
-    params.ddelta_ptr = ddelta.data_ptr();
-    params.ddelta_bias_ptr = ddelta_bias_ptr;
-    params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
-    // All strides are in elements, not bytes.
-    params.dout_batch_stride = dout.stride(0);
-    params.dout_d_stride = dout.stride(1);
-    params.dA_d_stride = dA.stride(0);
-    params.dA_dstate_stride = dA.stride(1);
-    if (!is_variable_B) {
-        params.dB_d_stride = dB.stride(0);
-    } else {
-        params.dB_batch_stride = dB.stride(0);
-        params.dB_group_stride = dB.stride(1);
-    }
-    params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
-    if (!is_variable_C) {
-        params.dC_d_stride = dC.stride(0);
-    } else {
-        params.dC_batch_stride = dC.stride(0);
-        params.dC_group_stride = dC.stride(1);
-    }
-    params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
-    params.du_batch_stride = du.stride(0);
-    params.du_d_stride = du.stride(1);
-    params.ddelta_batch_stride = ddelta.stride(0);
-    params.ddelta_d_stride = ddelta.stride(1);
-    if (has_z) {
-        params.dz_batch_stride = dz.stride(0);
-        params.dz_d_stride = dz.stride(1);
-    }
-}
-
-std::vector<at::Tensor>
-selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
-                  const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
-                  const c10::optional<at::Tensor> &D_,
-                  const c10::optional<at::Tensor> &z_,
-                  const c10::optional<at::Tensor> &delta_bias_,
-                  bool delta_softplus) {
-    auto input_type = u.scalar_type();
-    auto weight_type = A.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
-
-    const bool is_variable_B = B.dim() >= 3;
-    const bool is_variable_C = C.dim() >= 3;
-    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
-
-    TORCH_CHECK(delta.scalar_type() == input_type);
-    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
-    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
-
-    TORCH_CHECK(u.is_cuda());
-    TORCH_CHECK(delta.is_cuda());
-    TORCH_CHECK(A.is_cuda());
-    TORCH_CHECK(B.is_cuda());
-    TORCH_CHECK(C.is_cuda());
-
-    TORCH_CHECK(u.stride(-1) == 1);
-    TORCH_CHECK(delta.stride(-1) == 1);
-
-    const auto sizes = u.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int dstate = A.size(1);
-    const int n_groups = is_variable_B ? B.size(1) : 1;
-
-    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
-
-    CHECK_SHAPE(u, batch_size, dim, seqlen);
-    CHECK_SHAPE(delta, batch_size, dim, seqlen);
-    CHECK_SHAPE(A, dim, dstate);
-    if (!is_variable_B) {
-        CHECK_SHAPE(B, dim, dstate);
-    } else {
-        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
-        TORCH_CHECK(B.stride(-1) == 1);
-    }
-    if (!is_variable_C) {
-        CHECK_SHAPE(C, dim, dstate);
-    } else {
-        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
-        TORCH_CHECK(C.stride(-1) == 1);
-    }
-
-    if (D_.has_value()) {
-        auto D = D_.value();
-        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
-        TORCH_CHECK(D.is_cuda());
-        TORCH_CHECK(D.stride(-1) == 1);
-        CHECK_SHAPE(D, dim);
-    }
-
-    if (delta_bias_.has_value()) {
-        auto delta_bias = delta_bias_.value();
-        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
-        TORCH_CHECK(delta_bias.is_cuda());
-        TORCH_CHECK(delta_bias.stride(-1) == 1);
-        CHECK_SHAPE(delta_bias, dim);
-    }
-
-    at::Tensor z, out_z;
-    const bool has_z = z_.has_value();
-    if (has_z) {
-        z = z_.value();
-        TORCH_CHECK(z.scalar_type() == input_type);
-        TORCH_CHECK(z.is_cuda());
-        TORCH_CHECK(z.stride(-1) == 1);
-        CHECK_SHAPE(z, batch_size, dim, seqlen);
-        out_z = torch::empty_like(z);
-    }
-
-    const int n_chunks = (seqlen + 2048 - 1) / 2048;
-    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
-    // at::Tensor out = torch::empty_like(u);
-    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
-    at::Tensor out = torch::empty_like(delta);
-    at::Tensor x;
-    x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
-
-    SSMParamsBase params;
-    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
-                       u, delta, A, B, C, out, z, out_z,
-                       D_.has_value() ? D_.value().data_ptr() : nullptr,
-                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
-                       x.data_ptr(),
-                       has_z,
-                       delta_softplus);
-
-    // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
-            selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
-        });
-    });
-    std::vector<at::Tensor> result = {out, x};
-    if (has_z) { result.push_back(out_z); }
-    return result;
-}
-
-std::vector<at::Tensor>
-selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
-                  const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
-                  const c10::optional<at::Tensor> &D_,
-                  const c10::optional<at::Tensor> &z_,
-                  const c10::optional<at::Tensor> &delta_bias_,
-                  const at::Tensor &dout,
-                  const c10::optional<at::Tensor> &x_,
-                  const c10::optional<at::Tensor> &out_,
-                  c10::optional<at::Tensor> &dz_,
-                  bool delta_softplus,
-                  bool recompute_out_z) {
-    auto input_type = u.scalar_type();
-    auto weight_type = A.scalar_type();
-    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
-    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
-
-    const bool is_variable_B = B.dim() >= 3;
-    const bool is_variable_C = C.dim() >= 3;
-    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
-
-    TORCH_CHECK(delta.scalar_type() == input_type);
-    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
-    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
-    TORCH_CHECK(dout.scalar_type() == input_type);
-
-    TORCH_CHECK(u.is_cuda());
-    TORCH_CHECK(delta.is_cuda());
-    TORCH_CHECK(A.is_cuda());
-    TORCH_CHECK(B.is_cuda());
-    TORCH_CHECK(C.is_cuda());
-    TORCH_CHECK(dout.is_cuda());
-
-    TORCH_CHECK(u.stride(-1) == 1);
-    TORCH_CHECK(delta.stride(-1) == 1);
-    TORCH_CHECK(dout.stride(-1) == 1);
-
-    const auto sizes = u.sizes();
-    const int batch_size = sizes[0];
-    const int dim = sizes[1];
-    const int seqlen = sizes[2];
-    const int dstate = A.size(1);
-    const int n_groups = is_variable_B ? B.size(1) : 1;
-
-    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
-
-    CHECK_SHAPE(u, batch_size, dim, seqlen);
-    CHECK_SHAPE(delta, batch_size, dim, seqlen);
-    CHECK_SHAPE(A, dim, dstate);
-    if (!is_variable_B) {
-        CHECK_SHAPE(B, dim, dstate);
-    } else {
-        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
-        TORCH_CHECK(B.stride(-1) == 1);
-    }
-    if (!is_variable_C) {
-        CHECK_SHAPE(C, dim, dstate);
-    } else {
-        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
-        TORCH_CHECK(C.stride(-1) == 1);
-    }
-    CHECK_SHAPE(dout, batch_size, dim, seqlen);
-
-    if (D_.has_value()) {
-        auto D = D_.value();
-        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
-        TORCH_CHECK(D.is_cuda());
-        TORCH_CHECK(D.stride(-1) == 1);
-        CHECK_SHAPE(D, dim);
-    }
-
-    if (delta_bias_.has_value()) {
-        auto delta_bias = delta_bias_.value();
-        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
-        TORCH_CHECK(delta_bias.is_cuda());
-        TORCH_CHECK(delta_bias.stride(-1) == 1);
-        CHECK_SHAPE(delta_bias, dim);
-    }
-
-    at::Tensor z, out, dz, out_z;
-    const bool has_z = z_.has_value();
-    if (has_z) {
-        z = z_.value();
-        TORCH_CHECK(z.scalar_type() == input_type);
-        TORCH_CHECK(z.is_cuda());
-        TORCH_CHECK(z.stride(-1) == 1);
-        CHECK_SHAPE(z, batch_size, dim, seqlen);
-
-        TORCH_CHECK(out_.has_value());
-        out = out_.value();
-        TORCH_CHECK(out.scalar_type() == input_type);
-        TORCH_CHECK(out.is_cuda());
-        TORCH_CHECK(out.stride(-1) == 1);
-        CHECK_SHAPE(out, batch_size, dim, seqlen);
-
-        if (dz_.has_value()) {
-            dz = dz_.value();
-            TORCH_CHECK(dz.scalar_type() == input_type);
-            TORCH_CHECK(dz.is_cuda());
-            TORCH_CHECK(dz.stride(-1) == 1);
-            CHECK_SHAPE(dz, batch_size, dim, seqlen);
-        } else {
-            dz = torch::empty_like(z);
-        }
-        if (recompute_out_z) {
-            out_z = torch::empty_like(out);
-        }
-    }
-
-    const int n_chunks = (seqlen + 2048 - 1) / 2048;
-    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
-    if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
-    if (x_.has_value()) {
-        auto x = x_.value();
-        TORCH_CHECK(x.scalar_type() == weight_type);
-        TORCH_CHECK(x.is_cuda());
-        TORCH_CHECK(x.is_contiguous());
-        CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
-    }
-
-    at::Tensor du = torch::empty_like(u);
-    at::Tensor ddelta = torch::empty_like(delta);
-    at::Tensor dA = torch::zeros_like(A);
-    at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
-    at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
-    at::Tensor dD;
-    if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
-    at::Tensor ddelta_bias;
-    if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
-
-    SSMParamsBwd params;
-    set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
-                       u, delta, A, B, C, z, out, out_z,
-                       D_.has_value() ? D_.value().data_ptr() : nullptr,
-                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
-                       x_.has_value() ? x_.value().data_ptr() : nullptr,
-                       dout, du, ddelta, dA, dB, dC, dz,
-                       D_.has_value() ? dD.data_ptr() : nullptr,
-                       delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
-                       has_z, delta_softplus, recompute_out_z);
-
-    // Otherwise the kernel will be launched from cuda:0 device
-    // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
-        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
-            selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
-        });
-    });
-    std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
-    if (has_z) { result.push_back(dz); }
-    if (recompute_out_z) { result.push_back(out_z); }
-    return result;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("fwd", &selective_scan_fwd, "Selective scan forward");
-    m.def("bwd", &selective_scan_bwd, "Selective scan backward");
-}
diff --git a/mamba/csrc/selective_scan/selective_scan.h b/mamba/csrc/selective_scan/selective_scan.h
deleted file mode 100644
index e2c7bcdbd5ddadc5975caa641ecb1dcd3b73dafd..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan.h
+++ /dev/null
@@ -1,101 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-struct SSMScanParamsBase {
-    using index_t = uint32_t;
-
-    int batch, seqlen, n_chunks;
-    index_t a_batch_stride;
-    index_t b_batch_stride;
-    index_t out_batch_stride;
-
-    // Common data pointers.
-    void *__restrict__ a_ptr;
-    void *__restrict__ b_ptr;
-    void *__restrict__ out_ptr;
-    void *__restrict__ x_ptr;
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-struct SSMParamsBase {
-    using index_t = uint32_t;
-
-    int batch, dim, seqlen, dstate, n_groups, n_chunks;
-    int dim_ngroups_ratio;
-    bool is_variable_B;
-    bool is_variable_C;
-
-    bool delta_softplus;
-
-    index_t A_d_stride;
-    index_t A_dstate_stride;
-    index_t B_batch_stride;
-    index_t B_d_stride;
-    index_t B_dstate_stride;
-    index_t B_group_stride;
-    index_t C_batch_stride;
-    index_t C_d_stride;
-    index_t C_dstate_stride;
-    index_t C_group_stride;
-    index_t u_batch_stride;
-    index_t u_d_stride;
-    index_t delta_batch_stride;
-    index_t delta_d_stride;
-    index_t z_batch_stride;
-    index_t z_d_stride;
-    index_t out_batch_stride;
-    index_t out_d_stride;
-    index_t out_z_batch_stride;
-    index_t out_z_d_stride;
-
-    // Common data pointers.
-    void *__restrict__ A_ptr;
-    void *__restrict__ B_ptr;
-    void *__restrict__ C_ptr;
-    void *__restrict__ D_ptr;
-    void *__restrict__ u_ptr;
-    void *__restrict__ delta_ptr;
-    void *__restrict__ delta_bias_ptr;
-    void *__restrict__ out_ptr;
-    void *__restrict__ x_ptr;
-    void *__restrict__ z_ptr;
-    void *__restrict__ out_z_ptr;
-};
-
-struct SSMParamsBwd: public SSMParamsBase {
-    index_t dout_batch_stride;
-    index_t dout_d_stride;
-    index_t dA_d_stride;
-    index_t dA_dstate_stride;
-    index_t dB_batch_stride;
-    index_t dB_group_stride;
-    index_t dB_d_stride;
-    index_t dB_dstate_stride;
-    index_t dC_batch_stride;
-    index_t dC_group_stride;
-    index_t dC_d_stride;
-    index_t dC_dstate_stride;
-    index_t du_batch_stride;
-    index_t du_d_stride;
-    index_t dz_batch_stride;
-    index_t dz_d_stride;
-    index_t ddelta_batch_stride;
-    index_t ddelta_d_stride;
-
-    // Common data pointers.
-    void *__restrict__ dout_ptr;
-    void *__restrict__ dA_ptr;
-    void *__restrict__ dB_ptr;
-    void *__restrict__ dC_ptr;
-    void *__restrict__ dD_ptr;
-    void *__restrict__ du_ptr;
-    void *__restrict__ dz_ptr;
-    void *__restrict__ ddelta_ptr;
-    void *__restrict__ ddelta_bias_ptr;
-};
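
Note: the deleted header keeps the kernels decoupled from ATen by flattening everything a launch needs into a plain struct of sizes, strides, and type-erased pointers that is passed to the kernel by value. A minimal host-side sketch of that pattern (plain C++, hypothetical names; an illustration of the idea, not the deleted code):

// params_sketch.cpp -- the "flat params struct" pattern behind SSMParamsBase:
// shapes and strides are precomputed on the host, data pointers are type-erased,
// and the consumer re-applies the strides itself.
#include <cstdint>
#include <cstdio>
#include <vector>

struct ScanParams {
    int batch, dim, seqlen;
    uint32_t u_batch_stride, u_d_stride;   // strides in elements, not bytes
    void* u_ptr;                            // type-erased; the consumer casts it back
};

// Fill the struct for a contiguous (batch, dim, seqlen) float tensor.
ScanParams set_params(int batch, int dim, int seqlen, float* u) {
    ScanParams p{};
    p.batch = batch; p.dim = dim; p.seqlen = seqlen;
    p.u_d_stride = static_cast<uint32_t>(seqlen);
    p.u_batch_stride = static_cast<uint32_t>(dim) * p.u_d_stride;
    p.u_ptr = u;
    return p;
}

// Stand-in for a kernel body: each (batch, dim) block locates its row via strides.
float sum_row(const ScanParams& p, int batch_id, int dim_id) {
    const float* u = static_cast<const float*>(p.u_ptr)
                   + batch_id * p.u_batch_stride + dim_id * p.u_d_stride;
    float s = 0.f;
    for (int t = 0; t < p.seqlen; ++t) s += u[t];
    return s;
}

int main() {
    std::vector<float> u(2 * 3 * 4, 1.0f);            // batch=2, dim=3, seqlen=4
    ScanParams p = set_params(2, 3, 4, u.data());
    std::printf("row sum = %f\n", sum_row(p, 1, 2));   // expect 4.0
    return 0;
}

Keeping only raw strides and void* pointers is what lets the same struct be passed by value into a CUDA kernel without pulling tensor types into device code.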
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu b/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu
deleted file mode 100644
index c55f0e858af4ebd246a5d251308ab920b4e01a50..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_bwd_kernel.cuh"
-
-template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu b/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu
deleted file mode 100644
index 72adaf5cb13c6429e2f345a0a823c6bc3722b95a..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_bwd_kernel.cuh"
-
-template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu
deleted file mode 100644
index df126d7c8d5f9f0862273d2fe21ea15b35757b64..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_bwd_kernel.cuh"
-
-template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu
deleted file mode 100644
index 3ff271b50eaff208ae33c16c87ab7aaee76dfd76..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_bwd_kernel.cuh"
-
-template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu
deleted file mode 100644
index 5554902342785b289b81c060a71a51734fc1e6bf..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_bwd_kernel.cuh"
-
-template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu
deleted file mode 100644
index a7ed642231da80c455c0499702cc8a1cb4536ec2..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu
+++ /dev/null
@@ -1,9 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_bwd_kernel.cuh"
-
-template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
\ No newline at end of file
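
Note: the six small .cu files removed above exist only to spread the explicit template instantiations of selective_scan_bwd_cuda across translation units so they can be compiled in parallel. A hypothetical file-layout sketch of the same idea in plain C++ (file boundaries marked in comments; names are illustrative):

// --- scan.h : declaration only, included by the caller -------------------------
#pragma once
template <typename input_t, typename weight_t>
void selective_scan_bwd_sketch(int seqlen);

// --- scan_kernel.h : the template definition (cf. the *_kernel.cuh headers) ----
#pragma once
#include <cstdio>
#include "scan.h"
template <typename input_t, typename weight_t>
void selective_scan_bwd_sketch(int seqlen) {
    std::printf("bwd<%zu-byte in, %zu-byte w> seqlen=%d\n",
                sizeof(input_t), sizeof(weight_t), seqlen);
}

// --- scan_fp32.cpp : one explicit instantiation per file, built independently --
#include "scan_kernel.h"
template void selective_scan_bwd_sketch<float, float>(int);

// --- scan_fp64.cpp --------------------------------------------------------------
#include "scan_kernel.h"
template void selective_scan_bwd_sketch<double, float>(int);

// --- main.cpp : sees only the declaration; the linker finds the instantiation --
#include "scan.h"
int main() { selective_scan_bwd_sketch<float, float>(4096); return 0; }

Each .cpp can then be compiled by a separate process (the same way nvcc builds the per-dtype .cu files here) and the objects linked together at the end.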
diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh
deleted file mode 100644
index 2ed101148a4b32560111e5a832fc8b5881a4b243..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh
+++ /dev/null
@@ -1,531 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <c10/util/BFloat16.h>
-#include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
-#include <ATen/cuda/Atomic.cuh>  // For atomicAdd on complex
-
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-#include <cub/block/block_scan.cuh>
-#include <cub/block/block_reduce.cuh>
-
-#include "selective_scan.h"
-#include "selective_scan_common.h"
-#include "reverse_scan.cuh"
-#include "static_switch.h"
-
-template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
-template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
-template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
-
-template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
-         bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
-struct Selective_Scan_bwd_kernel_traits {
-    static_assert(kNItems_ % 4 == 0);
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    static constexpr int kNItems = kNItems_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
-    static_assert(kNItems % kNElts == 0);
-    static constexpr int kNLoads = kNItems / kNElts;
-    static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
-    static constexpr bool kIsEvenLen = kIsEvenLen_;
-    static constexpr bool kIsVariableB = kIsVariableB_;
-    static constexpr bool kIsVariableC = kIsVariableC_;
-    static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
-    static constexpr bool kHasZ = kHasZ_;
-    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
-    // For complex this would lead to massive register spilling, so we keep it at 2.
-    static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
-    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
-    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
-    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
-    using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
-    using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
-    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
-    using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
-    using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
-    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
-                                                 sizeof(typename BlockLoadVecT::TempStorage),
-                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
-                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
-                                                 sizeof(typename BlockStoreT::TempStorage),
-                                                 sizeof(typename BlockStoreVecT::TempStorage)});
-    static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
-    static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
-    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
-void selective_scan_bwd_kernel(SSMParamsBwd params) {
-    constexpr bool kIsComplex = Ktraits::kIsComplex;
-    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
-    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
-    constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
-    constexpr bool kHasZ = Ktraits::kHasZ;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr int kNItems = Ktraits::kNItems;
-    using input_t = typename Ktraits::input_t;
-    using weight_t = typename Ktraits::weight_t;
-    using scan_t = typename Ktraits::scan_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-    // cast to lvalue reference of expected type
-    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
-    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
-    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
-    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
-    auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
-    auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
-    auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
-    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
-    auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
-    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
-    auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
-    weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
-    scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
-    weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
-    weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
-
-    const int batch_id = blockIdx.x;
-    const int dim_id = blockIdx.y;
-    const int group_id = dim_id / (params.dim_ngroups_ratio);
-    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
-        + dim_id * params.u_d_stride;
-    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
-        + dim_id * params.delta_d_stride;
-    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
-        + dim_id * params.dout_d_stride;
-    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
-    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
-    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
-    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
-    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
-    weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
-    weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
-        + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
-    weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
-        + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
-    float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
-    float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
-    float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
-    float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
-    scan_t *x = params.x_ptr == nullptr
-        ? nullptr
-        : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
-    float dD_val = 0;
-    float ddelta_bias_val = 0;
-
-    constexpr int kChunkSize = kNThreads * kNItems;
-    u += (params.n_chunks - 1) * kChunkSize;
-    delta += (params.n_chunks - 1) * kChunkSize;
-    dout += (params.n_chunks - 1) * kChunkSize;
-    Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
-    Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
-    for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
-        input_t u_vals[kNItems];
-        input_t delta_vals_load[kNItems];
-        input_t dout_vals_load[kNItems];
-        __syncthreads();
-        load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
-        u -= kChunkSize;
-        __syncthreads();
-        load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
-        // Will reload delta at the same location if kDeltaSoftplus
-        if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
-        __syncthreads();
-        load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
-        dout -= kChunkSize;
-
-        float dout_vals[kNItems], delta_vals[kNItems];
-        #pragma unroll
-        for (int i = 0; i < kNItems; ++i) {
-            dout_vals[i] = float(dout_vals_load[i]);
-            delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
-            if constexpr (kDeltaSoftplus) {
-                delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
-            }
-        }
-
-        if constexpr (kHasZ) {
-            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
-                + dim_id * params.z_d_stride + chunk * kChunkSize;
-            input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
-                + dim_id * params.out_d_stride + chunk * kChunkSize;
-            input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
-                + dim_id * params.dz_d_stride + chunk * kChunkSize;
-            input_t z_vals[kNItems], out_vals[kNItems];
-            __syncthreads();
-            load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
-            __syncthreads();
-            load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
-            float dz_vals[kNItems], z_silu_vals[kNItems];
-            #pragma unroll
-            for (int i = 0; i < kNItems; ++i) {
-                float z_val = z_vals[i];
-                float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
-                z_silu_vals[i] = z_val * z_sigmoid_val;
-                dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
-                             * (1.0f + z_val * (1.0f - z_sigmoid_val));
-                dout_vals[i] *= z_silu_vals[i];
-            }
-            __syncthreads();
-            store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
-            if (params.out_z_ptr != nullptr) {  // Recompute and store out_z
-                float out_z_vals[kNItems];
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
-                // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
-                    // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
-                // }
-                input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
-                    + dim_id * params.out_z_d_stride + chunk * kChunkSize;
-                __syncthreads();
-                store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
-            }
-        }
-
-        float du_vals[kNItems];
-        #pragma unroll
-        for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
-        #pragma unroll
-        for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
-
-        float ddelta_vals[kNItems] = {0};
-        __syncthreads();
-        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
-            const weight_t A_val = A[state_idx * params.A_dstate_stride];
-            // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
-            weight_t A_scaled;
-            constexpr float kLog2e = M_LOG2E;
-            if constexpr (!kIsComplex) {
-                A_scaled = A_val * kLog2e;
-            } else {
-                A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
-            }
-            weight_t B_val, C_val;
-            weight_t B_vals[kNItems], C_vals[kNItems];
-            if constexpr (!kIsVariableB) {
-                B_val = B[state_idx * params.B_dstate_stride];
-            } else {
-                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
-                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
-            }
-            if constexpr (!kIsVariableC) {
-                C_val = C[state_idx * params.C_dstate_stride];
-            } else {
-                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
-                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
-                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
-            }
-            // const weight_t A_val = smem_a[state_idx];
-            scan_t thread_data[kNItems], thread_reverse_data[kNItems];
-            if constexpr (!kIsComplex) {
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
-                    thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
-                    if (i == 0) {
-                        smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
-                    } else {
-                        thread_reverse_data[i - 1].x = delta_a_exp;
-                    }
-                    thread_reverse_data[i].y = dout_vals[i] *
-                        (!kIsVariableC
-                         ? (!kIsVariableB ? B_val * C_val : C_val)
-                         : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
-                }
-                __syncthreads();
-                thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
-                    ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
-                    : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
-                // Initialize running total
-                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
-                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
-                Ktraits::BlockScanT(smem_scan).InclusiveScan(
-                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
-                );
-                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
-                SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
-                Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
-                    thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
-                );
-                if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
-                weight_t dA_val = 0, dBC_val = 0;
-                weight_t dB_vals[kNItems], dC_vals[kNItems];
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    const float dx = thread_reverse_data[i].y;
-                    const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
-                    du_vals[i] += ddelta_u * delta_vals[i];
-                    const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
-                    ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
-                    dA_val += dx * delta_vals[i] * a;
-                    if constexpr (!kIsVariableB || !kIsVariableC) {
-                        if constexpr (!kIsVariableB) {  // dBC_val is dB_val
-                            dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
-                        } else {  // dBC_val is dC_val
-                            dBC_val += dout_vals[i] * thread_data[i].y;
-                        }
-                    }
-                    if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
-                    if constexpr (kIsVariableC) {
-                        dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
-                    }
-                }
-                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
-                if constexpr (kIsVariableB || kIsVariableC) {
-                    if constexpr (kIsVariableB) {
-                        Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
-                    }
-                    if constexpr (kIsVariableC) {
-                        auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
-                        Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
-                    }
-                    const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
-                    weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
-                    weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
-                    #pragma unroll
-                    for (int i = 0; i < kNItems; ++i) {
-                        if (i * kNThreads < seqlen_remaining) {
-                            if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
-                            if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
-                        }
-                    }
-                }
-                if constexpr (!kIsVariableB || !kIsVariableC) {
-                    float2 dA_dBC_val = make_float2(dA_val, dBC_val);
-                    dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
-                    dA_val = dA_dBC_val.x;
-                    if (threadIdx.x == 0) {
-                        smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
-                    }
-                } else {
-                    dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
-                }
-                if (threadIdx.x == 0) {
-                    smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
-                }
-            } else {
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    // Pytorch's implementation of complex exp (which calls thrust) is very slow
-                    complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
-                    weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
-                    thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
-                    if (i == 0) {
-                        smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
-                    } else {
-                        thread_reverse_data[i - 1].x = delta_a_exp.real_;
-                        thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
-                    }
-                    complex_t dout_BC = 2 * dout_vals[i]
-                        * conj(!kIsVariableC
-                                ? (!kIsVariableB ? B_val * C_val : C_val)
-                                : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
-                    thread_reverse_data[i].z = dout_BC.real_;
-                    thread_reverse_data[i].w = dout_BC.imag_;
-                }
-                __syncthreads();
-                complex_t delta_a_exp = threadIdx.x == kNThreads - 1
-                    ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
-                    : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
-                thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
-                thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
-                // Initialize running total
-                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
-                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
-                Ktraits::BlockScanT(smem_scan).InclusiveScan(
-                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
-                );
-                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
-                SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
-                Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
-                    thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
-                );
-                if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
-                weight_t dA_val = 0, dBC_val = 0;
-                weight_t dB_vals[kNItems], dC_vals[kNItems];
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
-                    complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
-                    float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
-                    if constexpr (!kIsVariableB || !kIsVariableC) {
-                        if constexpr (!kIsVariableB) {  // dBC_val is dB_val
-                            dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
-                        } else {  // dBC_val is dC_val
-                            dBC_val += (2 * dout_vals[i]) * conj(x);
-                        }
-                    }
-                    const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
-                    du_vals[i] += ddelta_u * delta_vals[i];
-                    ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
-                    dA_val += delta_vals[i] * dx * a_conj;
-                    if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
-                    if constexpr (kIsVariableC) {
-                        dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
-                    }
-                }
-                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
-                if constexpr (kIsVariableB || kIsVariableC) {
-                    float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
-                    if constexpr (kIsVariableB) {
-                        #pragma unroll
-                        for (int i = 0; i < kNItems; ++i) {
-                            dB_vals_f[i * 2] = dB_vals[i].real_;
-                            dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
-                        }
-                        Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
-                    }
-                    if constexpr (kIsVariableC) {
-                        #pragma unroll
-                        for (int i = 0; i < kNItems; ++i) {
-                            dC_vals_f[i * 2] = dC_vals[i].real_;
-                            dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
-                        }
-                        auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
-                        Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
-                    }
-                    const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
-                    float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
-                    float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
-                    #pragma unroll
-                    for (int i = 0; i < kNItems * 2; ++i) {
-                        if (i * kNThreads < seqlen_remaining) {
-                            if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
-                            if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
-                        }
-                    }
-                }
-                if constexpr (!kIsVariableB || !kIsVariableC) {
-                    float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
-                    dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
-                    dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
-                    dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
-                    if (threadIdx.x == 0) {
-                        smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
-                    }
-                } else {
-                    dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
-                }
-                if (threadIdx.x == 0) {
-                    smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
-                }
-            }
-        }
-
-        if constexpr (kDeltaSoftplus) {
-            __syncthreads();
-            input_t delta_vals_load[kNItems];
-            load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
-            delta -= kChunkSize;
-            #pragma unroll
-            for (int i = 0; i < kNItems; ++i) {
-                float delta_val = float(delta_vals_load[i]) + delta_bias;
-                float delta_val_neg_exp = expf(-delta_val);
-                ddelta_vals[i] = delta_val <= 20.f
-                    ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
-                    : ddelta_vals[i];
-            }
-        }
-        for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
-
-        input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
-            + dim_id * params.du_d_stride + chunk * kChunkSize;
-        input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
-            + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
-        __syncthreads();
-        store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
-        __syncthreads();
-        store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
-
-        Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
-        Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
-    }
-    if (params.dD_ptr != nullptr) {
-        dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
-        if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
-    }
-    if (params.ddelta_bias_ptr != nullptr) {
-        __syncthreads();
-        ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
-        if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
-    }
-    for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
-        gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
-        weight_t dBC_val;
-        if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
-        if constexpr (!kIsVariableB) {
-            gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
-                         !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
-        }
-        if constexpr (!kIsVariableC) {
-            gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
-                        !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
-        }
-    }
-}
-
-template<int kNThreads, int kNItems, typename input_t, typename weight_t>
-void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
-            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
-                BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
-                    BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
-                        using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
-                        // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
-                        // TODO: check this
-                        constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
-                        // printf("smem_size = %d\n", kSmemSize);
-                        dim3 grid(params.batch, params.dim);
-                        auto kernel = &selective_scan_bwd_kernel<Ktraits>;
-                        if (kSmemSize >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        }
-                        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
-                });
-            });
-        });
-    });
-}
-
-template<typename input_t, typename weight_t>
-void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
-    if (params.seqlen <= 128) {
-        selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 256) {
-        selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 512) {
-        selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 1024) {
-        selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
-    } else {
-        selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
-    }
-}
\ No newline at end of file
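
Note: the deleted backward kernel walks the sequence chunk by chunk from the end, pairing a forward inclusive scan (to recover the hidden states) with a reverse inclusive scan (to accumulate the gradient flowing backwards through h[t] = a[t]*h[t-1] + b[t]). Reduced to a single channel and state, the recurrence it implements looks like this reference sketch (plain C++; the loss L = sum_t g[t]*h[t] and all names are hypothetical, this illustrates the math rather than the deleted kernel):

// reverse_scan_sketch.cpp -- reference backward pass for the linear recurrence
//   h[t] = a[t] * h[t-1] + b[t],  h[-1] = 0,  L = sum_t g[t] * h[t]
// The gradient obeys a *reverse* recurrence, which is why the kernel pairs a
// forward scan (states) with a reverse scan (state gradients):
//   dh[t] = g[t] + a[t+1] * dh[t+1],   da[t] = dh[t] * h[t-1],   db[t] = dh[t]
#include <cstdio>
#include <vector>

int main() {
    std::vector<double> a{0.9, 0.8, 0.7, 0.6};
    std::vector<double> b{1.0, 2.0, 3.0, 4.0};
    std::vector<double> g{1.0, 1.0, 1.0, 1.0};   // upstream gradient dL/dh[t]
    const int T = static_cast<int>(a.size());

    // Forward scan: hidden states.
    std::vector<double> h(T);
    double prev = 0.0;
    for (int t = 0; t < T; ++t) { h[t] = a[t] * prev + b[t]; prev = h[t]; }

    // Reverse scan: gradients w.r.t. the state, then w.r.t. a and b.
    std::vector<double> dh(T), da(T), db(T);
    double carry = 0.0;                           // dh[T] = 0
    for (int t = T - 1; t >= 0; --t) {
        dh[t] = g[t] + carry;                     // g[t] + a[t+1] * dh[t+1]
        carry = a[t] * dh[t];                     // carried into step t-1
        da[t] = dh[t] * (t > 0 ? h[t - 1] : 0.0);
        db[t] = dh[t];
    }

    for (int t = 0; t < T; ++t)
        std::printf("t=%d  h=%.4f  dh=%.4f  da=%.4f  db=%.4f\n", t, h[t], dh[t], da[t], db[t]);
    return 0;
}

In the kernel above, a[t] = exp(delta[t]*A) and b[t] = delta[t]*B[t]*u[t], so these dh terms are further chained into ddelta, dA, dB, dC and du, and the reverse-scan carry between chunks is held in smem_running_postfix.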
diff --git a/mamba/csrc/selective_scan/selective_scan_common.h b/mamba/csrc/selective_scan/selective_scan_common.h
deleted file mode 100644
index 9140dcdf3b68ad2de95bcd3fd9543a9d320cef68..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_common.h
+++ /dev/null
@@ -1,221 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <c10/util/complex.h>  // For scalar_value_type
-
-#define MAX_DSTATE 256
-
-using complex_t = c10::complex<float>;
-
-inline __device__ float2 operator+(const float2 & a, const float2 & b){
-    return {a.x + b.x, a.y + b.y};
-}
-
-inline __device__ float3 operator+(const float3 &a, const float3 &b) {
-  return {a.x + b.x, a.y + b.y, a.z + b.z};
-}
-
-inline __device__ float4 operator+(const float4 & a, const float4 & b){
-    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int BYTES> struct BytesToType {};
-
-template<> struct BytesToType<16> {
-    using Type = uint4;
-    static_assert(sizeof(Type) == 16);
-};
-
-template<> struct BytesToType<8> {
-    using Type = uint64_t;
-    static_assert(sizeof(Type) == 8);
-};
-
-template<> struct BytesToType<4> {
-    using Type = uint32_t;
-    static_assert(sizeof(Type) == 4);
-};
-
-template<> struct BytesToType<2> {
-    using Type = uint16_t;
-    static_assert(sizeof(Type) == 2);
-};
-
-template<> struct BytesToType<1> {
-    using Type = uint8_t;
-    static_assert(sizeof(Type) == 1);
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename scalar_t, int N>
-struct Converter{
-    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
-        #pragma unroll
-        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
-    }
-};
-
-template<int N>
-struct Converter<at::Half, N>{
-    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
-        static_assert(N % 2 == 0);
-        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
-        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
-        #pragma unroll
-        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
-    }
-};
-
-#if __CUDA_ARCH__ >= 800
-template<int N>
-struct Converter<at::BFloat16, N>{
-    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
-        static_assert(N % 2 == 0);
-        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
-        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
-        #pragma unroll
-        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
-    }
-};
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
-// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
-__device__ __forceinline__ complex_t cexp2f(complex_t z) {
-    float t = exp2f(z.real_);
-    float c, s;
-    sincosf(z.imag_, &s, &c);
-    return complex_t(c * t, s * t);
-}
-
-__device__ __forceinline__ complex_t cexpf(complex_t z) {
-    float t = expf(z.real_);
-    float c, s;
-    sincosf(z.imag_, &s, &c);
-    return complex_t(c * t, s * t);
-}
-
-template<typename scalar_t> struct SSMScanOp;
-
-template<>
-struct SSMScanOp<float> {
-    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
-        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
-    }
-};
-
-template<>
-struct SSMScanOp<complex_t> {
-    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
-        complex_t a0 = complex_t(ab0.x, ab0.y);
-        complex_t b0 = complex_t(ab0.z, ab0.w);
-        complex_t a1 = complex_t(ab1.x, ab1.y);
-        complex_t b1 = complex_t(ab1.z, ab1.w);
-        complex_t out_a = a1 * a0;
-        complex_t out_b = a1 * b0 + b1;
-        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
-    }
-};
-
-// A stateful callback functor that maintains a running prefix to be applied
-// during consecutive scan operations.
-template <typename scalar_t> struct SSMScanPrefixCallbackOp {
-    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
-    scan_t running_prefix;
-    // Constructor
-    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
-    // Callback operator to be entered by the first warp of threads in the block.
-    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
-    __device__ scan_t operator()(scan_t block_aggregate) {
-        scan_t old_prefix = running_prefix;
-        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
-        return old_prefix;
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename Ktraits>
-inline __device__ void load_input(typename Ktraits::input_t *u,
-                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
-                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
-                                  int seqlen) {
-    if constexpr (Ktraits::kIsEvenLen) {
-        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
-        using vec_t = typename Ktraits::vec_t;
-        Ktraits::BlockLoadVecT(smem_load_vec).Load(
-            reinterpret_cast<vec_t*>(u),
-            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
-       );
-    } else {
-        Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
-    }
-}
-
-template<typename Ktraits>
-inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
-                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
-                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
-                                   int seqlen) {
-    constexpr int kNItems = Ktraits::kNItems;
-    if constexpr (!Ktraits::kIsComplex) {
-        typename Ktraits::input_t B_vals_load[kNItems];
-        if constexpr (Ktraits::kIsEvenLen) {
-            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
-            using vec_t = typename Ktraits::vec_t;
-            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
-                reinterpret_cast<vec_t*>(Bvar),
-                reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
-          );
-        } else {
-            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
-        }
-        // #pragma unroll
-        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
-        Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
-    } else {
-        typename Ktraits::input_t B_vals_load[kNItems * 2];
-        if constexpr (Ktraits::kIsEvenLen) {
-            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
-            using vec_t = typename Ktraits::vec_t;
-            Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
-                reinterpret_cast<vec_t*>(Bvar),
-                reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
-          );
-        } else {
-            Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
-        }
-        #pragma unroll
-        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
-    }
-}
-
-template<typename Ktraits>
-inline __device__ void store_output(typename Ktraits::input_t *out,
-                                    const float (&out_vals)[Ktraits::kNItems],
-                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
-                                    int seqlen) {
-    typename Ktraits::input_t write_vals[Ktraits::kNItems];
-    #pragma unroll
-    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
-    if constexpr (Ktraits::kIsEvenLen) {
-        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
-        using vec_t = typename Ktraits::vec_t;
-        Ktraits::BlockStoreVecT(smem_store_vec).Store(
-            reinterpret_cast<vec_t*>(out),
-            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
-       );
-    } else {
-        Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
-    }
-}
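
Note: SSMScanOp in the deleted header is what lets the recurrence run as a block-wide scan. Each element carries a pair (a, b) representing the affine map h -> a*h + b; composing two such maps is again an affine map, so the operator is associative and cub::BlockScan can apply it in parallel. A small sequential check of that composition (plain C++; the pair and combine mirror the real-valued SSMScanOp, everything else is illustrative):

// scan_op_sketch.cpp -- the associative combine behind SSMScanOp<float>:
//   (a0, b0) then (a1, b1)  ==  h -> a1*(a0*h + b0) + b1  ==  (a1*a0, a1*b0 + b1)
// An inclusive scan with this operator reproduces h[t] = a[t]*h[t-1] + b[t]
// (with h[-1] = 0) in the second component of each prefix.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

struct AB { double a, b; };                       // the affine map h -> a*h + b

AB combine(const AB& prev, const AB& next) {      // apply `prev` first, then `next`
    return {next.a * prev.a, next.a * prev.b + next.b};
}

int main() {
    std::vector<AB> xs{{0.9, 1.0}, {0.8, 2.0}, {0.7, 3.0}, {0.6, 4.0}};

    // Direct recurrence.
    std::vector<double> h(xs.size());
    double prev = 0.0;
    for (size_t t = 0; t < xs.size(); ++t) { h[t] = xs[t].a * prev + xs[t].b; prev = h[t]; }

    // Inclusive scan with the associative combine.
    AB acc = xs[0];
    for (size_t t = 0; t < xs.size(); ++t) {
        if (t > 0) acc = combine(acc, xs[t]);
        assert(std::fabs(acc.b - h[t]) < 1e-12);  // second component == hidden state
        std::printf("t=%zu  prefix=(%.4f, %.4f)  h=%.4f\n", t, acc.a, acc.b, h[t]);
    }
    return 0;
}

The complex-weight path packs the same pair into a float4 (real and imaginary parts of a and b), which is why scan_t switches between float2 and float4 in these kernels.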
diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu b/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu
deleted file mode 100644
index 2b8615b1d522c119125d4cb6ff3dce42f2bd4659..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_fwd_kernel.cuh"
-
-template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu b/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu
deleted file mode 100644
index 015e2a0eff633daf2693e43a2648008652a38c7c..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_fwd_kernel.cuh"
-
-template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu b/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu
deleted file mode 100644
index c142fe0208ea784679122ba04997d3432b05efcc..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu
+++ /dev/null
@@ -1,10 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-// Split into multiple files to compile in parallel
-
-#include "selective_scan_fwd_kernel.cuh"
-
-template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
-template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh
deleted file mode 100644
index 440a209108bfe120c73d123bbf0b82ccf43a5638..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ /dev/null
@@ -1,345 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
-
-#pragma once
-
-#include <c10/util/BFloat16.h>
-#include <c10/util/Half.h>
-#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
-
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-#include <cub/block/block_scan.cuh>
-
-#include "selective_scan.h"
-#include "selective_scan_common.h"
-#include "static_switch.h"
-
-template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
-         bool kIsVariableB_, bool kIsVariableC_,
-         bool kHasZ_, typename input_t_, typename weight_t_>
-struct Selective_Scan_fwd_kernel_traits {
-    static_assert(kNItems_ % 4 == 0);
-    using input_t = input_t_;
-    using weight_t = weight_t_;
-    static constexpr int kNThreads = kNThreads_;
-    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
-    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
-    static constexpr int kNItems = kNItems_;
-    static constexpr int kNRows = kNRows_;
-    static constexpr int kNBytes = sizeof(input_t);
-    static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
-    static_assert(kNItems % kNElts == 0);
-    static constexpr int kNLoads = kNItems / kNElts;
-    static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
-    static constexpr bool kIsEvenLen = kIsEvenLen_;
-    static constexpr bool kIsVariableB = kIsVariableB_;
-    static constexpr bool kIsVariableC = kIsVariableC_;
-    static constexpr bool kHasZ = kHasZ_;
-
-    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
-
-    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
-    using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
-    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
-        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
-    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
-    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
-        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE  : cub::BLOCK_LOAD_DIRECT>;
-    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
-    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
-        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
-    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
-    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
-    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
-    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
-                                                 sizeof(typename BlockLoadVecT::TempStorage),
-                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
-                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
-                                                 sizeof(typename BlockStoreT::TempStorage),
-                                                 sizeof(typename BlockStoreVecT::TempStorage)});
-    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
-};
-
-template<typename Ktraits>
-__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
-void selective_scan_fwd_kernel(SSMParamsBase params) {
-    constexpr bool kIsComplex = Ktraits::kIsComplex;
-    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
-    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
-    constexpr bool kHasZ = Ktraits::kHasZ;
-    constexpr int kNThreads = Ktraits::kNThreads;
-    constexpr int kNItems = Ktraits::kNItems;
-    constexpr int kNRows = Ktraits::kNRows;
-    constexpr bool kDirectIO = Ktraits::kDirectIO;
-    using input_t = typename Ktraits::input_t;
-    using weight_t = typename Ktraits::weight_t;
-    using scan_t = typename Ktraits::scan_t;
-
-    // Shared memory.
-    extern __shared__ char smem_[];
-    // cast to lvalue reference of expected type
-    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
-    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
-    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
-    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
-    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
-    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
-    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
-    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
-
-    const int batch_id = blockIdx.x;
-    const int dim_id = blockIdx.y;
-    const int group_id = dim_id / (params.dim_ngroups_ratio);
-    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
-        + dim_id * kNRows * params.u_d_stride;
-    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
-        + dim_id * kNRows * params.delta_d_stride;
-    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
-    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
-    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
-    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
-    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
-    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
-
-    float D_val[kNRows] = {0};
-    if (params.D_ptr != nullptr) {
-        #pragma unroll
-        for (int r = 0; r < kNRows; ++r) {
-            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
-        }
-    }
-    float delta_bias[kNRows] = {0};
-    if (params.delta_bias_ptr != nullptr) {
-        #pragma unroll
-        for (int r = 0; r < kNRows; ++r) {
-            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
-        }
-    }
-
-    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
-    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
-    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
-    // }
-
-    constexpr int kChunkSize = kNThreads * kNItems;
-    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
-        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
-        __syncthreads();
-        #pragma unroll
-        for (int r = 0; r < kNRows; ++r) {
-            if constexpr (!kDirectIO) {
-                if (r > 0) { __syncthreads(); }
-            }
-            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
-            if constexpr (!kDirectIO) { __syncthreads(); }
-            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
-        }
-        u += kChunkSize;
-        delta += kChunkSize;
-
-        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
-        #pragma unroll
-        for (int r = 0; r < kNRows; ++r) {
-            #pragma unroll
-            for (int i = 0; i < kNItems; ++i) {
-                float u_val = float(u_vals[r][i]);
-                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
-                if (params.delta_softplus) {
-                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
-                }
-                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
-                out_vals[r][i] = D_val[r] * u_val;
-            }
-        }
-
-        __syncthreads();
-        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
-            weight_t A_val[kNRows];
-            #pragma unroll
-            for (int r = 0; r < kNRows; ++r) {
-                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
-                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
-                constexpr float kLog2e = M_LOG2E;
-                if constexpr (!kIsComplex) {
-                    A_val[r] *= kLog2e;
-                } else {
-                    A_val[r].real_ *= kLog2e;
-                }
-            }
-            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
-            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
-            // If both B and C vary, this is unused.
-            weight_t BC_val[kNRows];
-            weight_t B_vals[kNItems], C_vals[kNItems];
-            if constexpr (kIsVariableB) {
-                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
-                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
-                if constexpr (!kIsVariableC) {
-                    #pragma unroll
-                    for (int r = 0; r < kNRows; ++r) {
-                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
-                    }
-                }
-            }
-            if constexpr (kIsVariableC) {
-                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
-                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
-                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
-                if constexpr (!kIsVariableB) {
-                    #pragma unroll
-                    for (int r = 0; r < kNRows; ++r) {
-                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
-                    }
-                }
-            }
-            if constexpr (!kIsVariableB && !kIsVariableC) {
-                #pragma unroll
-                for (int r = 0; r < kNRows; ++r) {
-                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
-                }
-            }
-
-            #pragma unroll
-            for (int r = 0; r < kNRows; ++r) {
-                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
-                scan_t thread_data[kNItems];
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    if constexpr (!kIsComplex) {
-                        thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
-                                                     !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
-                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
-                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
-                                thread_data[i] = make_float2(1.f, 0.f);
-                            }
-                        }
-                    } else {
-                        // Pytorch's implementation of complex exp (which calls thrust) is very slow
-                        complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
-                        weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
-                        thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
-                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
-                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
-                                thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
-                            }
-                        }
-                    }
-                }
-                // Initialize running total
-                scan_t running_prefix;
-                if constexpr (!kIsComplex) {
-                    // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
-                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
-                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
-                } else {
-                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
-                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
-                }
-                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
-                Ktraits::BlockScanT(smem_scan).InclusiveScan(
-                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
-                );
-                // There's a syncthreads in the scan op, so we don't need to sync here.
-                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
-                if (threadIdx.x == 0) {
-                    smem_running_prefix[state_idx] = prefix_op.running_prefix;
-                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
-                }
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    const weight_t C_val = !kIsVariableC
-                        ? BC_val[r]
-                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
-                    if constexpr (!kIsComplex) {
-                        out_vals[r][i] += thread_data[i].y * C_val;
-                    } else {
-                        out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
-                    }
-                }
-            }
-        }
-
-        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
-            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
-        __syncthreads();
-        #pragma unroll
-        for (int r = 0; r < kNRows; ++r) {
-            if constexpr (!kDirectIO) {
-                if (r > 0) { __syncthreads(); }
-            }
-            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
-        }
-
-        if constexpr (kHasZ) {
-            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
-                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
-            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
-                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
-            #pragma unroll
-            for (int r = 0; r < kNRows; ++r) {
-                input_t z_vals[kNItems];
-                __syncthreads();
-                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
-                #pragma unroll
-                for (int i = 0; i < kNItems; ++i) {
-                    float z_val = z_vals[i];
-                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
-                }
-                __syncthreads();
-                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
-            }
-        }
-
-        Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
-        Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
-    }
-}
-
-template<int kNThreads, int kNItems, typename input_t, typename weight_t>
-void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
-    // Only kNRows == 1 is tested for now, which of course does not differ from the previous version
-    // where each block processed 1 row.
-    constexpr int kNRows = 1;
-    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
-        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
-            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
-                BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
-                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
-                    // constexpr int kSmemSize = Ktraits::kSmemSize;
-                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-                    // printf("smem_size = %d\n", kSmemSize);
-                    dim3 grid(params.batch, params.dim / kNRows);
-                    auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-                    if (kSmemSize >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                    }
-                    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                });
-            });
-        });
-    });
-}
-
-template<typename input_t, typename weight_t>
-void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
-    if (params.seqlen <= 128) {
-        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 256) {
-        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 512) {
-        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 1024) {
-        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
-    } else {
-        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
-    }
-}
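The deleted kernel above processes the sequence in chunks of kNThreads * kNItems timesteps per thread block, carries the running state across chunks through smem_running_prefix, and parallelizes a first-order recurrence with a block-wide inclusive scan. A minimal sequential Python sketch of that recurrence for a single (dim, dstate) pair, assuming real weights, delta_softplus, and the D skip connection (the function name and scalar shapes are illustrative, not the kernel's API):

import math

def selective_scan_1d(u, delta, A, B, C, D=0.0, delta_bias=0.0, delta_softplus=True):
    # Sequential reference for one (dim, dstate = 1) pair of the scan implemented above.
    x = 0.0  # running state; corresponds to the (1, 0) identity prefix in the kernel
    out = []
    for u_t, dt, B_t, C_t in zip(u, delta, B, C):
        dt = dt + delta_bias
        if delta_softplus:
            dt = math.log1p(math.exp(dt)) if dt <= 20.0 else dt
        x = math.exp(dt * A) * x + dt * B_t * u_t   # scan element: (exp(dt*A), dt*B_t*u_t)
        out.append(C_t * x + D * u_t)               # y_t = C_t * x_t + D * u_t
    return out

y = selective_scan_1d([1.0, 2.0, 3.0], [0.1, 0.2, 0.3], -1.0, [1.0] * 3, [1.0] * 3)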
diff --git a/mamba/csrc/selective_scan/static_switch.h b/mamba/csrc/selective_scan/static_switch.h
deleted file mode 100644
index 7920ac045d0a2a1f4c4159ee3eebe51fe1e2c203..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/static_switch.h
+++ /dev/null
@@ -1,25 +0,0 @@
-// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
-// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
-
-#pragma once
-
-/// @param COND       - a boolean expression to switch by
-/// @param CONST_NAME - a name given for the constexpr bool variable.
-/// @param ...       - code to execute for true and false
-///
-/// Usage:
-/// ```
-/// BOOL_SWITCH(flag, BoolConst, [&] {
-///     some_function<BoolConst>(...);
-/// });
-/// ```
-#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
-    [&] {                                                                            \
-        if (COND) {                                                                  \
-            constexpr bool CONST_NAME = true;                                        \
-            return __VA_ARGS__();                                                    \
-        } else {                                                                     \
-            constexpr bool CONST_NAME = false;                                       \
-            return __VA_ARGS__();                                                    \
-        }                                                                            \
-    }()
diff --git a/mamba/csrc/selective_scan/uninitialized_copy.cuh b/mamba/csrc/selective_scan/uninitialized_copy.cuh
deleted file mode 100644
index 630622dddcc9041737307810000584a843a01764..0000000000000000000000000000000000000000
--- a/mamba/csrc/selective_scan/uninitialized_copy.cuh
+++ /dev/null
@@ -1,69 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2011-2022, NVIDIA CORPORATION.  All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright
- *       notice, this list of conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright
- *       notice, this list of conditions and the following disclaimer in the
- *       documentation and/or other materials provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the
- *       names of its contributors may be used to endorse or promote products
- *       derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-#include <cub/config.cuh>
-
-#include <cuda/std/type_traits>
-
-
-namespace detail
-{
-
-#if defined(_NVHPC_CUDA)
-template <typename T, typename U>
-__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
-{
-  // NVBug 3384810
-  new (ptr) T(::cuda::std::forward<U>(val));
-}
-#else
-template <typename T,
-          typename U,
-          typename ::cuda::std::enable_if<
-            ::cuda::std::is_trivially_copyable<T>::value,
-            int
-          >::type = 0>
-__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
-{
-  *ptr = ::cuda::std::forward<U>(val);
-}
-
-template <typename T,
-         typename U,
-         typename ::cuda::std::enable_if<
-           !::cuda::std::is_trivially_copyable<T>::value,
-           int
-         >::type = 0>
-__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
-{
-  new (ptr) T(::cuda::std::forward<U>(val));
-}
-#endif
-
-} // namespace detail
diff --git a/mamba/evals/lm_harness_eval.py b/mamba/evals/lm_harness_eval.py
deleted file mode 100644
index d09d40534cf53be4d1387666697c82aa53add625..0000000000000000000000000000000000000000
--- a/mamba/evals/lm_harness_eval.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import torch
-
-import transformers
-from transformers import AutoTokenizer
-
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
-
-from lm_eval.api.model import LM
-from lm_eval.models.huggingface import HFLM
-from lm_eval.api.registry import register_model
-from lm_eval.__main__ import cli_evaluate
-
-
-@register_model("mamba")
-class MambaEvalWrapper(HFLM):
-
-    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
-
-    def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
-                 dtype=torch.float16):
-        LM.__init__(self)
-        self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
-        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
-        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-        self.vocab_size = self.tokenizer.vocab_size
-        self._batch_size = int(batch_size) if batch_size is not None else 64
-        self._max_length = max_length
-        self._device = torch.device(device)
-
-    @property
-    def batch_size(self):
-        return self._batch_size
-
-    def _model_generate(self, context, max_length, stop, **generation_kwargs):
-        raise NotImplementedError()
-
-
-if __name__ == "__main__":
-    cli_evaluate()
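With the "mamba" wrapper registered above, this script is typically driven through the lm-evaluation-harness CLI; a plausible invocation (the task list and batch size are illustrative) looks like:

python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b --tasks lambada_openai,hellaswag --device cuda --batch_size 64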
diff --git a/mamba/mamba_ssm/__init__.py b/mamba/mamba_ssm/__init__.py
deleted file mode 100644
index 2ecd144db5dbec72bcfcdcea28c624a7e2bf053b..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-__version__ = "1.0.1"
-
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn
-from mamba_ssm.modules.mamba_simple import Mamba
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
diff --git a/mamba/mamba_ssm/models/__init__.py b/mamba/mamba_ssm/models/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mamba/mamba_ssm/models/mixer_seq_simple.py b/mamba/mamba_ssm/models/mixer_seq_simple.py
deleted file mode 100644
index 383f773f1f700cd53176e51327a5d8dc58158da0..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/models/mixer_seq_simple.py
+++ /dev/null
@@ -1,233 +0,0 @@
-# Copyright (c) 2023, Albert Gu, Tri Dao.
-
-import math
-from functools import partial
-
-from collections import namedtuple
-
-import torch
-import torch.nn as nn
-
-from mamba_ssm.modules.mamba_simple import Mamba, Block
-from mamba_ssm.utils.generation import GenerationMixin
-from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
-
-try:
-    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
-except ImportError:
-    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
-
-
-def create_block(
-    d_model,
-    ssm_cfg=None,
-    norm_epsilon=1e-5,
-    rms_norm=False,
-    residual_in_fp32=False,
-    fused_add_norm=False,
-    layer_idx=None,
-    device=None,
-    dtype=None,
-):
-    if ssm_cfg is None:
-        ssm_cfg = {}
-    factory_kwargs = {"device": device, "dtype": dtype}
-    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
-    norm_cls = partial(
-        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
-    )
-    block = Block(
-        d_model,
-        mixer_cls,
-        norm_cls=norm_cls,
-        fused_add_norm=fused_add_norm,
-        residual_in_fp32=residual_in_fp32,
-    )
-    block.layer_idx = layer_idx
-    return block
-
-
-# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
-def _init_weights(
-    module,
-    n_layer,
-    initializer_range=0.02,  # Now only used for embedding layer.
-    rescale_prenorm_residual=True,
-    n_residuals_per_layer=1,  # Change to 2 if we have MLP
-):
-    if isinstance(module, nn.Linear):
-        if module.bias is not None:
-            if not getattr(module.bias, "_no_reinit", False):
-                nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.Embedding):
-        nn.init.normal_(module.weight, std=initializer_range)
-
-    if rescale_prenorm_residual:
-        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
-        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
-        #
-        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-        for name, p in module.named_parameters():
-            if name in ["out_proj.weight", "fc2.weight"]:
-                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                # We need to reinit p since this code could be called multiple times
-                # Having just p *= scale would repeatedly scale it down
-                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-                with torch.no_grad():
-                    p /= math.sqrt(n_residuals_per_layer * n_layer)
-
-
-class MixerModel(nn.Module):
-    def __init__(
-        self,
-        d_model: int,
-        n_layer: int,
-        vocab_size: int,
-        ssm_cfg=None,
-        norm_epsilon: float = 1e-5,
-        rms_norm: bool = False,
-        initializer_cfg=None,
-        fused_add_norm=False,
-        residual_in_fp32=False,
-        device=None,
-        dtype=None,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.residual_in_fp32 = residual_in_fp32
-
-        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
-
-        # We change the order of residual and layer norm:
-        # Instead of LN -> Attn / MLP -> Add, we do:
-        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
-        # the main branch (output of MLP / Mixer). The model definition is unchanged.
-        # This is for performance reason: we can fuse add + layer_norm.
-        self.fused_add_norm = fused_add_norm
-        if self.fused_add_norm:
-            if layer_norm_fn is None or rms_norm_fn is None:
-                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
-
-        self.layers = nn.ModuleList(
-            [
-                create_block(
-                    d_model,
-                    ssm_cfg=ssm_cfg,
-                    norm_epsilon=norm_epsilon,
-                    rms_norm=rms_norm,
-                    residual_in_fp32=residual_in_fp32,
-                    fused_add_norm=fused_add_norm,
-                    layer_idx=i,
-                    **factory_kwargs,
-                )
-                for i in range(n_layer)
-            ]
-        )
-
-        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
-            d_model, eps=norm_epsilon, **factory_kwargs
-        )
-
-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=n_layer,
-                **(initializer_cfg if initializer_cfg is not None else {}),
-            )
-        )
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return {
-            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
-            for i, layer in enumerate(self.layers)
-        }
-
-    def forward(self, input_ids, inference_params=None):
-        hidden_states = self.embedding(input_ids)
-        residual = None
-        for layer in self.layers:
-            hidden_states, residual = layer(
-                hidden_states, residual, inference_params=inference_params
-            )
-        if not self.fused_add_norm:
-            residual = (hidden_states + residual) if residual is not None else hidden_states
-            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
-        else:
-            # Set prenorm=False here since we don't need the residual
-            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
-            hidden_states = fused_add_norm_fn(
-                hidden_states,
-                self.norm_f.weight,
-                self.norm_f.bias,
-                eps=self.norm_f.eps,
-                residual=residual,
-                prenorm=False,
-                residual_in_fp32=self.residual_in_fp32,
-            )
-        return hidden_states
-
-
-class MambaLMHeadModel(nn.Module, GenerationMixin):
-
-    def __init__(
-        self,
-        d_model: int,
-        n_layer: int,
-        vocab_size: int,
-        initializer_cfg=None,
-        pad_vocab_size_multiple: int = 1,
-        device=None,
-        dtype=None,
-        **backbone_kwargs,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        if vocab_size % pad_vocab_size_multiple != 0:
-            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
-        self.backbone = MixerModel(
-            d_model=d_model,
-            n_layer=n_layer,
-            vocab_size=vocab_size,
-            initializer_cfg=initializer_cfg,
-            **backbone_kwargs,
-            **factory_kwargs,
-        )
-        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
-
-        # Initialize weights and apply final processing
-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=n_layer,
-                **(initializer_cfg if initializer_cfg is not None else {}),
-            )
-        )
-        self.tie_weights()
-
-    def tie_weights(self):
-        self.lm_head.weight = self.backbone.embedding.weight
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
-
-    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
-        """
-        "position_ids" is just to be compatible with Transformer generation. We don't use it.
-        num_last_tokens: if > 0, only return the logits for the last n tokens
-        """
-        hidden_states = self.backbone(input_ids, inference_params=inference_params)
-        if num_last_tokens > 0:
-            hidden_states = hidden_states[:, -num_last_tokens:]
-        lm_logits = self.lm_head(hidden_states)
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
-        return CausalLMOutput(logits=lm_logits)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
-        config = load_config_hf(pretrained_model_name)
-        model = cls(**config, device=device, dtype=dtype, **kwargs)
-        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
-        return model
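A minimal usage sketch for the deleted language-model wrapper above, assuming a Mamba checkpoint such as state-spaces/mamba-130m is available on the Hugging Face Hub (the repo id and token ids are illustrative):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, 1000, (1, 16), device="cuda")  # (batch, seqlen) token ids
logits = model(input_ids).logits                            # (batch, seqlen, padded vocab size)
print(logits.shape)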
diff --git a/mamba/mamba_ssm/modules/__init__.py b/mamba/mamba_ssm/modules/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mamba/mamba_ssm/modules/mamba_simple.py b/mamba/mamba_ssm/modules/mamba_simple.py
deleted file mode 100644
index 2a1dd8f808b50d632bbd22f0648d4cb8939cb1e1..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/modules/mamba_simple.py
+++ /dev/null
@@ -1,418 +0,0 @@
-# Copyright (c) 2023, Tri Dao, Albert Gu.
-
-import math
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-
-from einops import rearrange, repeat
-
-try:
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-except ImportError:
-    causal_conv1d_fn, causal_conv1d_update = None, None
-
-try:
-    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
-except ImportError:
-    selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None
-
-try:
-    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
-except ImportError:
-    selective_state_update = None
-
-try:
-    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
-except ImportError:
-    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
-
-
-class Mamba(nn.Module):
-    def __init__(
-        self,
-        d_model,
-        d_state=16,
-        d_conv=4,
-        expand=2,
-        dt_rank="auto",
-        dt_min=0.001,
-        dt_max=0.1,
-        dt_init="random",
-        dt_scale=1.0,
-        dt_init_floor=1e-4,
-        conv_bias=True,
-        bias=False,
-        use_fast_path=True,  # Fused kernel options
-        layer_idx=None,
-        device=None,
-        dtype=None,
-        bimamba=True,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.d_model = d_model
-        self.d_state = d_state
-        self.d_conv = d_conv
-        self.expand = expand
-        self.d_inner = int(self.expand * self.d_model)
-        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
-        self.use_fast_path = use_fast_path
-        self.layer_idx = layer_idx
-        self.bimamba = bimamba
-
-        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
-
-        self.conv1d = nn.Conv1d(
-            in_channels=self.d_inner,
-            out_channels=self.d_inner,
-            bias=conv_bias,
-            kernel_size=d_conv,
-            groups=self.d_inner,
-            padding=d_conv - 1,
-            **factory_kwargs,
-        )
-
-        self.activation = "silu"
-        self.act = nn.SiLU()
-
-        self.x_proj = nn.Linear(
-            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
-        )
-        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
-
-        # Initialize special dt projection to preserve variance at initialization
-        dt_init_std = self.dt_rank**-0.5 * dt_scale
-        if dt_init == "constant":
-            nn.init.constant_(self.dt_proj.weight, dt_init_std)
-        elif dt_init == "random":
-            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
-        else:
-            raise NotImplementedError
-
-        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
-        dt = torch.exp(
-            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
-            + math.log(dt_min)
-        ).clamp(min=dt_init_floor)
-        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-        inv_dt = dt + torch.log(-torch.expm1(-dt))
-        with torch.no_grad():
-            self.dt_proj.bias.copy_(inv_dt)
-        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
-        self.dt_proj.bias._no_reinit = True
-
-        # S4D real initialization
-        # NOTE: why plus 1?
-        A = repeat(
-            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
-            "n -> d n",
-            d=self.d_inner,
-        ).contiguous()
-        A_log = torch.log(A)  # Keep A_log in fp32
-        self.A_log = nn.Parameter(A_log)
-        self.A_log._no_weight_decay = True
-
-        # D "skip" parameter
-        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
-        self.D._no_weight_decay = True
-
-        # bidirectional
-        # forked from https://github.com/hustvl/Vim
-        if self.bimamba:
-            A_b = repeat(
-                torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
-                "n -> d n",
-                d=self.d_inner,
-            ).contiguous()
-            A_b_log = torch.log(A_b)  # Keep A_b_log in fp32
-            self.A_b_log = nn.Parameter(A_b_log)
-            self.A_b_log._no_weight_decay = True 
-
-            self.conv1d_b = nn.Conv1d(
-                in_channels=self.d_inner,
-                out_channels=self.d_inner,
-                bias=conv_bias,
-                kernel_size=d_conv,
-                groups=self.d_inner,
-                padding=d_conv - 1,
-                **factory_kwargs,
-            )
-
-            self.x_proj_b = nn.Linear(
-                self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
-            )
-            self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
-
-            self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
-            self.D_b._no_weight_decay = True
-
-        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
-
-    def forward(self, hidden_states, inference_params=None, T=1):
-        """
-        hidden_states: (B, L, D)
-        Returns: same shape as hidden_states
-        """
-        batch, seqlen, dim = hidden_states.shape
-
-        conv_state, ssm_state = None, None
-        if inference_params is not None:
-            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
-            if inference_params.seqlen_offset > 0:
-                # The states are updated inplace
-                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
-                return out
-
-        # We do matmul and transpose BLH -> HBL at the same time
-        # NOTE: same as in_proj(hidden_states) but memory-efficient with the following operations
-        xz = rearrange(
-            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
-            "d (b l) -> b d l",
-            l=seqlen,
-        )
-        if self.in_proj.bias is not None:
-            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
-
-        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
-        # In the backward pass we write dx and dz next to each other to avoid torch.cat
-        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
-            if self.bimamba:
-                A_b = -torch.exp(self.A_b_log.float())
-                out = mamba_inner_fn_no_out_proj(
-                    xz,
-                    self.conv1d.weight,
-                    self.conv1d.bias,
-                    self.x_proj.weight,
-                    self.dt_proj.weight,
-                    A,
-                    None,  # input-dependent B
-                    None,  # input-dependent C
-                    self.D.float(),
-                    delta_bias=self.dt_proj.bias.float(),
-                    delta_softplus=True,
-                )
-                out_b = mamba_inner_fn_no_out_proj(
-                    xz.flip([-1]),
-                    self.conv1d_b.weight,
-                    self.conv1d_b.bias,
-                    self.x_proj_b.weight,
-                    self.dt_proj_b.weight,
-                    A_b,
-                    None,
-                    None,
-                    self.D_b.float(),
-                    delta_bias=self.dt_proj_b.bias.float(),
-                    delta_softplus=True,
-                )
-                out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
-            else:
-                out = mamba_inner_fn(
-                    xz,
-                    self.conv1d.weight,
-                    self.conv1d.bias,
-                    self.x_proj.weight,
-                    self.dt_proj.weight,
-                    self.out_proj.weight,
-                    self.out_proj.bias,
-                    A,
-                    None,  # input-dependent B
-                    None,  # input-dependent C
-                    self.D.float(),
-                    delta_bias=self.dt_proj.bias.float(),
-                    delta_softplus=True,
-                )
-        else:
-            x, z = xz.chunk(2, dim=1)
-            # Compute short convolution
-            if conv_state is not None:
-                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
-            if causal_conv1d_fn is None:
-                x = self.act(self.conv1d(x)[..., :seqlen])
-            else:
-                assert self.activation in ["silu", "swish"]
-                x = causal_conv1d_fn(
-                    x,
-                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                    self.conv1d.bias,
-                    self.activation,
-                )
-
-            # We're careful here about the layout, to avoid extra transposes.
-            # We want dt to have d as the slowest moving dimension
-            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
-            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-            dt = self.dt_proj.weight @ dt.t()
-            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
-            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-            assert self.activation in ["silu", "swish"]
-            y = selective_scan_fn(
-                x,
-                dt,
-                A,
-                B,
-                C,
-                self.D.float(),
-                z=z,
-                delta_bias=self.dt_proj.bias.float(),
-                delta_softplus=True,
-                return_last_state=ssm_state is not None,
-            )
-            if ssm_state is not None:
-                y, last_state = y
-                ssm_state.copy_(last_state)
-            y = rearrange(y, "b d l -> b l d")
-            out = self.out_proj(y)
-        return out
-
-    def step(self, hidden_states, conv_state, ssm_state):
-        dtype = hidden_states.dtype
-        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
-        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
-        x, z = xz.chunk(2, dim=-1)  # (B D)
-
-        # Conv step
-        if causal_conv1d_update is None:
-            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
-            conv_state[:, :, -1] = x
-            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
-            if self.conv1d.bias is not None:
-                x = x + self.conv1d.bias
-            x = self.act(x).to(dtype=dtype)
-        else:
-            x = causal_conv1d_update(
-                x,
-                conv_state,
-                rearrange(self.conv1d.weight, "d 1 w -> d w"),
-                self.conv1d.bias,
-                self.activation,
-            )
-
-        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
-        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-        # Don't add dt_bias here
-        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
-        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
-
-        # SSM step
-        if selective_state_update is None:
-            # Discretize A and B
-            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
-            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
-            dB = torch.einsum("bd,bn->bdn", dt, B)
-            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
-            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
-            y = y + self.D.to(dtype) * x
-            y = y * self.act(z)  # (B D)
-        else:
-            y = selective_state_update(
-                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
-            )
-
-        out = self.out_proj(y)
-        return out.unsqueeze(1), conv_state, ssm_state
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        device = self.out_proj.weight.device
-        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
-        conv_state = torch.zeros(
-            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
-        )
-        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
-        # ssm_dtype = torch.float32
-        ssm_state = torch.zeros(
-            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
-        )
-        return conv_state, ssm_state
-
-    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
-        assert self.layer_idx is not None
-        if self.layer_idx not in inference_params.key_value_memory_dict:
-            batch_shape = (batch_size,)
-            conv_state = torch.zeros(
-                batch_size,
-                self.d_model * self.expand,
-                self.d_conv,
-                device=self.conv1d.weight.device,
-                dtype=self.conv1d.weight.dtype,
-            )
-            ssm_state = torch.zeros(
-                batch_size,
-                self.d_model * self.expand,
-                self.d_state,
-                device=self.dt_proj.weight.device,
-                dtype=self.dt_proj.weight.dtype,
-                # dtype=torch.float32,
-            )
-            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
-        else:
-            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
-            # TODO: What if batch size changes between generation, and we reuse the same states?
-            if initialize_states:
-                conv_state.zero_()
-                ssm_state.zero_()
-        return conv_state, ssm_state
-
-
-class Block(nn.Module):
-    def __init__(
-        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
-    ):
-        """
-        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
-
-        This Block has a slightly different structure compared to a regular
-        prenorm Transformer block.
-        The standard block is: LN -> MHA/MLP -> Add.
-        [Ref: https://arxiv.org/abs/2002.04745]
-        Here we have: Add -> LN -> Mixer, returning both
-        the hidden_states (output of the mixer) and the residual.
-        This is purely for performance reasons, as we can fuse add and LayerNorm.
-        The residual needs to be provided (except for the very first block).
-        """
-        super().__init__()
-        self.residual_in_fp32 = residual_in_fp32
-        self.fused_add_norm = fused_add_norm
-        self.mixer = mixer_cls(dim)
-        self.norm = norm_cls(dim)
-        if self.fused_add_norm:
-            assert RMSNorm is not None, "RMSNorm import fails"
-            assert isinstance(
-                self.norm, (nn.LayerNorm, RMSNorm)
-            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
-
-    def forward(
-        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
-    ):
-        r"""Pass the input through the encoder layer.
-
-        Args:
-            hidden_states: the sequence to the encoder layer (required).
-            residual: hidden_states = Mixer(LN(residual))
-        """
-        if not self.fused_add_norm:
-            residual = (hidden_states + residual) if residual is not None else hidden_states
-            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
-            if self.residual_in_fp32:
-                residual = residual.to(torch.float32)
-        else:
-            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
-            hidden_states, residual = fused_add_norm_fn(
-                hidden_states,
-                self.norm.weight,
-                self.norm.bias,
-                residual=residual,
-                prenorm=True,
-                residual_in_fp32=self.residual_in_fp32,
-                eps=self.norm.eps,
-            )
-        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
-        return hidden_states, residual
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
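A short sketch showing how the deleted bidirectional Mamba mixer above could be exercised on its own, assuming a CUDA device and the compiled causal_conv1d / selective_scan extensions (the dimensions are illustrative):

import torch
from mamba_ssm.modules.mamba_simple import Mamba

layer = Mamba(d_model=192, d_state=16, d_conv=4, expand=2, bimamba=True).to("cuda")
x = torch.randn(2, 64, 192, device="cuda")  # (B, L, D), as in the forward docstring
y = layer(x)                                # returns the same shape as the input
print(y.shape)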
diff --git a/mamba/mamba_ssm/ops/__init__.py b/mamba/mamba_ssm/ops/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mamba/mamba_ssm/ops/selective_scan_interface.py b/mamba/mamba_ssm/ops/selective_scan_interface.py
deleted file mode 100644
index 99b455ed949c123bb453922d5ac88d00f401e392..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/ops/selective_scan_interface.py
+++ /dev/null
@@ -1,709 +0,0 @@
-# Copyright (c) 2023, Tri Dao, Albert Gu.
-
-import torch
-import torch.nn.functional as F
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-from einops import rearrange, repeat
-
-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
-import selective_scan_cuda
-
-
-class SelectiveScanFn(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                return_last_state=False):
-        if u.stride(-1) != 1:
-            u = u.contiguous()
-        if delta.stride(-1) != 1:
-            delta = delta.contiguous()
-        if D is not None:
-            D = D.contiguous()
-        if B.stride(-1) != 1:
-            B = B.contiguous()
-        if C.stride(-1) != 1:
-            C = C.contiguous()
-        if z is not None and z.stride(-1) != 1:
-            z = z.contiguous()
-        if B.dim() == 3:
-            B = rearrange(B, "b dstate l -> b 1 dstate l")
-            ctx.squeeze_B = True
-        if C.dim() == 3:
-            C = rearrange(C, "b dstate l -> b 1 dstate l")
-            ctx.squeeze_C = True
-        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
-        ctx.delta_softplus = delta_softplus
-        ctx.has_z = z is not None
-        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
-        if not ctx.has_z:
-            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
-            return out if not return_last_state else (out, last_state)
-        else:
-            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
-            out_z = rest[0]
-            return out_z if not return_last_state else (out_z, last_state)
-
-    @staticmethod
-    def backward(ctx, dout, *args):
-        if not ctx.has_z:
-            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
-            z = None
-            out = None
-        else:
-            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
-        if dout.stride(-1) != 1:
-            dout = dout.contiguous()
-        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
-        # backward of selective_scan_cuda with the backward of chunk).
-        # Here we just pass in None and dz will be allocated in the C++ code.
-        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
-            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
-            False  # option to recompute out_z, not used here
-        )
-        dz = rest[0] if ctx.has_z else None
-        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
-        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
-        return (du, ddelta, dA, dB, dC,
-                dD if D is not None else None,
-                dz,
-                ddelta_bias if delta_bias is not None else None,
-                None,
-                None)
-
-
-def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                     return_last_state=False):
-    """if return_last_state is True, returns (out, last_state)
-    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
-    not considered in the backward pass.
-    """
-    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
-
-
-def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                      return_last_state=False):
-    """
-    u: r(B D L)
-    delta: r(B D L)
-    A: c(D N) or r(D N)
-    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
-    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or r(B G N 2L)
-    D: r(D)
-    z: r(B D L)
-    delta_bias: r(D), fp32
-
-    out: r(B D L)
-    last_state (optional): r(B D dstate) or c(B D dstate)
-    """
-    dtype_in = u.dtype
-    u = u.float()
-    delta = delta.float()
-    if delta_bias is not None:
-        delta = delta + delta_bias[..., None].float()
-    if delta_softplus:
-        delta = F.softplus(delta)
-    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
-    is_variable_B = B.dim() >= 3
-    is_variable_C = C.dim() >= 3
-    if A.is_complex():
-        if is_variable_B:
-            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
-        if is_variable_C:
-            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
-    else:
-        B = B.float()
-        C = C.float()
-    x = A.new_zeros((batch, dim, dstate))
-    ys = []
-    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
-    if not is_variable_B:
-        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
-    else:
-        if B.dim() == 3:
-            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
-        else:
-            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
-            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
-    if is_variable_C and C.dim() == 4:
-        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
-    last_state = None
-    for i in range(u.shape[2]):
-        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
-        if not is_variable_C:
-            y = torch.einsum('bdn,dn->bd', x, C)
-        else:
-            if C.dim() == 3:
-                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
-            else:
-                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
-        if i == u.shape[2] - 1:
-            last_state = x
-        if y.is_complex():
-            y = y.real * 2
-        ys.append(y)
-    y = torch.stack(ys, dim=2) # (batch dim L)
-    out = y if D is None else y + u * rearrange(D, "d -> d 1")
-    if z is not None:
-        out = out * F.silu(z)
-    out = out.to(dtype=dtype_in)
-    return out if not return_last_state else (out, last_state)
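# A quick way to exercise the pure-PyTorch reference above on CPU (selective_scan_fn
# accepts the same arguments but needs the compiled CUDA extension); the shapes follow
# the docstring of selective_scan_ref and the sizes here are illustrative only.
import torch
b, d, l, n = 2, 4, 32, 8
u = torch.randn(b, d, l)
delta = torch.rand(b, d, l)
A = -torch.rand(d, n)                   # real and negative, as produced by -exp(A_log)
B_var = torch.randn(b, n, l)            # input-dependent B
C_var = torch.randn(b, n, l)            # input-dependent C
D_skip = torch.ones(d)
z = torch.randn(b, d, l)
out, last_state = selective_scan_ref(u, delta, A, B_var, C_var, D=D_skip, z=z,
                                     delta_bias=torch.zeros(d), delta_softplus=True,
                                     return_last_state=True)
print(out.shape, last_state.shape)      # (b, d, l) and (b, d, n)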
-
-
-class MambaInnerFnNoOutProj(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
-        """
-             xz: (batch, dim, seqlen)
-        """
-        assert checkpoint_lvl in [0, 1]
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        if torch.is_autocast_enabled():
-            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-        if xz.stride(-1) != 1:
-            xz = xz.contiguous()
-        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
-        x, z = xz.chunk(2, dim=1)
-        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
-        # We're being very careful here about the layout, to avoid extra transposes.
-        # We want delta to have d as the slowest moving dimension
-        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
-        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
-        ctx.is_variable_B = B is None
-        ctx.is_variable_C = C is None
-        ctx.B_proj_bias_is_None = B_proj_bias is None
-        ctx.C_proj_bias_is_None = C_proj_bias is None
-        if B is None:  # variable B
-            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
-            if B_proj_bias is not None:
-                B = B + B_proj_bias.to(dtype=B.dtype)
-            if not A.is_complex():
-                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
-                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if B.stride(-1) != 1:
-                B = B.contiguous()
-        if C is None:  # variable C
-            C = x_dbl[:, -d_state:]  # (bl dstate)
-            if C_proj_bias is not None:
-                C = C + C_proj_bias.to(dtype=C.dtype)
-            if not A.is_complex():
-                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
-                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if C.stride(-1) != 1:
-                C = C.contiguous()
-        if D is not None:
-            D = D.contiguous()
-        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
-        )
-        ctx.delta_softplus = delta_softplus
-        ctx.checkpoint_lvl = checkpoint_lvl
-        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
-            conv1d_out, delta = None, None
-        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
-                              delta_proj_weight, conv1d_out, delta,
-                              A, B, C, D, delta_bias, scan_intermediates, out)
-        # return rearrange(out_z, "b d l -> b l d")
-        return out_z
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, dout):
-        # dout: (batch, dim, seqlen)
-        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, 
-         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        x, z = xz.chunk(2, dim=1)
-        if dout.stride(-1) != 1:
-            dout = dout.contiguous()
-        if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
-            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
-                              "d (b l) -> b d l", l = L)
-        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
-        # backward of selective_scan_cuda with the backward of chunk).
-        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
-        dx, dz = dxz.chunk(2, dim=1)
-        # dout_y = rearrange(dout, "b l d -> b d l")  # not needed: forward has no rearrange, so dout is already (b, d, l)
-        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
-            ctx.delta_softplus,
-            True  # option to recompute out_z
-        )
-        dD = dD if D is not None else None
-        dx_dbl = torch.empty_like(x_dbl)
-        dB_proj_bias = None
-        if ctx.is_variable_B:
-            if not A.is_complex():
-                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
-            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
-            dB = None
-        dC_proj_bias = None
-        if ctx.is_variable_C:
-            if not A.is_complex():
-                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
-            dx_dbl[:, -d_state:] = dC  # (bl d)
-            dC = None
-        ddelta = rearrange(ddelta, "b d l -> d (b l)")
-        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
-        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
-        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
-        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
-        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
-        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
-        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-        # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
-        )
-        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
-        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
-        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
-                dA, dB, dC, dD,
-                ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
-    
-
-class MambaInnerFn(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                out_proj_weight, out_proj_bias,
-                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
-        """
-             xz: (batch, dim, seqlen)
-        """
-        assert checkpoint_lvl in [0, 1]
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        if torch.is_autocast_enabled():
-            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
-                             if out_proj_bias is not None else None)
-        if xz.stride(-1) != 1:
-            xz = xz.contiguous()
-        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
-        x, z = xz.chunk(2, dim=1)
-        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
-        # We're being very careful here about the layout, to avoid extra transposes.
-        # We want delta to have d as the slowest moving dimension
-        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
-        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
-        ctx.is_variable_B = B is None
-        ctx.is_variable_C = C is None
-        ctx.B_proj_bias_is_None = B_proj_bias is None
-        ctx.C_proj_bias_is_None = C_proj_bias is None
-        if B is None:  # variable B
-            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
-            if B_proj_bias is not None:
-                B = B + B_proj_bias.to(dtype=B.dtype)
-            if not A.is_complex():
-                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
-                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if B.stride(-1) != 1:
-                B = B.contiguous()
-        if C is None:  # variable C
-            C = x_dbl[:, -d_state:]  # (bl dstate)
-            if C_proj_bias is not None:
-                C = C + C_proj_bias.to(dtype=C.dtype)
-            if not A.is_complex():
-                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
-                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if C.stride(-1) != 1:
-                C = C.contiguous()
-        if D is not None:
-            D = D.contiguous()
-        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
-        )
-        ctx.delta_softplus = delta_softplus
-        ctx.out_proj_bias_is_None = out_proj_bias is None
-        ctx.checkpoint_lvl = checkpoint_lvl
-        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
-            conv1d_out, delta = None, None
-        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
-                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
-                              A, B, C, D, delta_bias, scan_intermediates, out)
-        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, dout):
-        # dout: (batch, seqlen, dim)
-        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
-         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        x, z = xz.chunk(2, dim=1)
-        if dout.stride(-1) != 1:
-            dout = dout.contiguous()
-        if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
-            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
-                              "d (b l) -> b d l", l = L)
-        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
-        # backward of selective_scan_cuda with the backward of chunk).
-        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
-        dx, dz = dxz.chunk(2, dim=1)
-        dout = rearrange(dout, "b l e -> e (b l)")
-        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
-        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
-            ctx.delta_softplus,
-            True  # option to recompute out_z
-        )
-        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
-        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
-        dD = dD if D is not None else None
-        dx_dbl = torch.empty_like(x_dbl)
-        dB_proj_bias = None
-        if ctx.is_variable_B:
-            if not A.is_complex():
-                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
-            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
-            dB = None
-        dC_proj_bias = None
-        if ctx.is_variable_C:
-            if not A.is_complex():
-                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
-            dx_dbl[:, -d_state:] = dC  # (bl d)
-            dC = None
-        ddelta = rearrange(ddelta, "b d l -> d (b l)")
-        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
-        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
-        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
-        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
-        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
-        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
-        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-        # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
-        )
-        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
-        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
-        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
-                dout_proj_weight, dout_proj_bias,
-                dA, dB, dC, dD,
-                ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
-
-
-class BiMambaInnerFn(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                out_proj_weight, out_proj_bias,
-                A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
-        """
-             xz: (batch, dim, seqlen)
-        """
-        assert checkpoint_lvl in [0, 1]
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        if torch.is_autocast_enabled():
-            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
-                             if out_proj_bias is not None else None)
-        if xz.stride(-1) != 1:
-            xz = xz.contiguous()
-        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
-        x, z = xz.chunk(2, dim=1)
-        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
-        # We're being very careful here about the layout, to avoid extra transposes.
-        # We want delta to have d as the slowest moving dimension
-        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
-        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
-        ctx.is_variable_B = B is None
-        ctx.is_variable_C = C is None
-        ctx.B_proj_bias_is_None = B_proj_bias is None
-        ctx.C_proj_bias_is_None = C_proj_bias is None
-        if B is None:  # variable B
-            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
-            if B_proj_bias is not None:
-                B = B + B_proj_bias.to(dtype=B.dtype)
-            if not A.is_complex():
-                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
-                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if B.stride(-1) != 1:
-                B = B.contiguous()
-        if C is None:  # variable C
-            C = x_dbl[:, -d_state:]  # (bl dstate)
-            if C_proj_bias is not None:
-                C = C + C_proj_bias.to(dtype=C.dtype)
-            if not A.is_complex():
-                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
-                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if C.stride(-1) != 1:
-                C = C.contiguous()
-        if D is not None:
-            D = D.contiguous()
-        out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
-        )
-        assert not A_b.is_complex(), "A_b should not be complex"
-        out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
-            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus,
-        )
-
-        out_z = out_z_f + out_z_b.flip([-1])
-
-        ctx.delta_softplus = delta_softplus
-        ctx.out_proj_bias_is_None = out_proj_bias is None
-        ctx.checkpoint_lvl = checkpoint_lvl
-        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
-            conv1d_out, delta = None, None
-        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
-                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
-                              A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b)
-        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, dout):
-        # dout: (batch, seqlen, dim)
-        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
-         conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        x, z = xz.chunk(2, dim=1)
-        if dout.stride(-1) != 1:
-            dout = dout.contiguous()
-        if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True)
-            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
-                              "d (b l) -> b d l", l = L)
-        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
-        # backward of selective_scan_cuda with the backward of chunk).
-        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
-        dx, dz = dxz.chunk(2, dim=1)
-        dout = rearrange(dout, "b l e -> e (b l)")
-        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
-        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
-            ctx.delta_softplus,
-            True  # option to recompute out_z
-        )
-        # flip one
-        dz_b = torch.empty_like(dz)
-        dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
-            conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
-            ctx.delta_softplus,
-            True  # option to recompute out_z
-        )
-
-        dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
-        ddelta = ddelta + ddelta_f_b.flip([-1])
-        dB = dB + dB_f_b.flip([-1])
-        dC = dC + dC_f_b.flip([-1])
-        dD = dD + dD_b
-        ddelta_bias = ddelta_bias + ddelta_bias_b
-        dz = dz + dz_b.flip([-1])
-        out_z = out_z_f + out_z_b.flip([-1])
-        
-        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
-        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
-        dD = dD if D is not None else None
-        dx_dbl = torch.empty_like(x_dbl)
-        dB_proj_bias = None
-        if ctx.is_variable_B:
-            if not A.is_complex():
-                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
-            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
-            dB = None
-        dC_proj_bias = None
-        if ctx.is_variable_C:
-            if not A.is_complex():
-                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
-            dx_dbl[:, -d_state:] = dC  # (bl d)
-            dC = None
-        ddelta = rearrange(ddelta, "b d l -> d (b l)")
-        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
-        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
-        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
-        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
-        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
-        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
-        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-        # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True
-        )
-        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
-        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
-        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
-                dout_proj_weight, dout_proj_bias,
-                dA, dA_b, dB, dC, dD,
-                ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
-    
-
-def mamba_inner_fn(
-    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-    out_proj_weight, out_proj_bias,
-    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
-):
-    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                              out_proj_weight, out_proj_bias,
-                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
-
-def bimamba_inner_fn(
-    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-    out_proj_weight, out_proj_bias,
-    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
-):
-    return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                              out_proj_weight, out_proj_bias,
-                              A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
-
-
-def mamba_inner_fn_no_out_proj(
-    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
-):
-    return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
-
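For reference, the shapes these fused entry points expect can be read off the slicing in the forward passes above (assuming a real-valued A); the summary below is an inferred annotation, not part of the original module:

# Inferred shape summary (annotation only); d_inner is the expanded model width,
# d_state the SSM state size, dt_rank the low-rank delta projection width.
#   xz:                (batch, 2 * d_inner, seqlen)    # split into x and z via chunk(2, dim=1)
#   conv1d_weight:     (d_inner, 1, d_conv)            # depthwise causal conv
#   x_proj_weight:     (dt_rank + 2 * d_state, d_inner)
#   delta_proj_weight: (d_inner, dt_rank)
#   out_proj_weight:   (d_model, d_inner)              # absent in the *_no_out_proj variant
#   A, A_b:            (d_inner, d_state)              # A_b (backward direction) must be real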
-
-def mamba_inner_ref(
-    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-    out_proj_weight, out_proj_bias,
-    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
-):
-    L = xz.shape[-1]
-    delta_rank = delta_proj_weight.shape[1]
-    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-    x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
-    # We're being very careful here about the layout, to avoid extra transposes.
-    # We want delta to have d as the slowest moving dimension
-    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
-    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
-    delta = rearrange(delta, "d (b l) -> b d l", l=L)
-    if B is None:  # variable B
-        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
-        if B_proj_bias is not None:
-            B = B + B_proj_bias.to(dtype=B.dtype)
-        if not A.is_complex():
-            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
-        else:
-            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
-    if C is None:  # variable C
-        C = x_dbl[:, -d_state:]  # (bl d)
-        if C_proj_bias is not None:
-            C = C + C_proj_bias.to(dtype=C.dtype)
-        if not A.is_complex():
-            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
-        else:
-            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
-    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
-    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
-
-
-def bimamba_inner_ref(
-    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-    out_proj_weight, out_proj_bias,
-    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True
-):
-    L = xz.shape[-1]
-    delta_rank = delta_proj_weight.shape[1]
-    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-    x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
-    # We're being very careful here about the layout, to avoid extra transposes.
-    # We want delta to have d as the slowest moving dimension
-    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
-    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
-    delta = rearrange(delta, "d (b l) -> b d l", l=L)
-    if B is None:  # variable B
-        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
-        if B_proj_bias is not None:
-            B = B + B_proj_bias.to(dtype=B.dtype)
-        if not A.is_complex():
-            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
-        else:
-            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
-    if C is None:  # variable C
-        C = x_dbl[:, -d_state:]  # (bl d)
-        if C_proj_bias is not None:
-            C = C + C_proj_bias.to(dtype=C.dtype)
-        if not A.is_complex():
-            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
-        else:
-            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
-    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
-    y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
-    y = y + y_b.flip([-1])
-    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
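The bidirectional variants above combine a forward scan with a scan over the time-reversed sequence, flipping the second result back before summing. A toy sketch (a cumulative sum standing in for the selective scan, purely illustrative) shows the flip pattern:

import torch

x = torch.arange(1.0, 6.0)                              # toy sequence of length 5
fwd = torch.cumsum(x, dim=-1)                           # forward "scan"
bwd = torch.cumsum(x.flip([-1]), dim=-1).flip([-1])     # scan the reversed input, flip back
combined = fwd + bwd                                    # same combination as out_z_f + out_z_b.flip([-1])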
diff --git a/mamba/mamba_ssm/ops/triton/__init__.py b/mamba/mamba_ssm/ops/triton/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mamba/mamba_ssm/ops/triton/layernorm.py b/mamba/mamba_ssm/ops/triton/layernorm.py
deleted file mode 100644
index 70d57397f6e5af1138e8df62629f9ab57174f6f4..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/ops/triton/layernorm.py
+++ /dev/null
@@ -1,636 +0,0 @@
-# Copyright (c) 2023, Tri Dao.
-# Implement residual + layer_norm / rms_norm.
-
-# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
-# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
-# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
-
-import math
-
-import torch
-import torch.nn.functional as F
-from torch.cuda.amp import custom_fwd, custom_bwd
-
-import triton
-import triton.language as tl
-
-
-def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
-    dtype = x.dtype
-    if upcast:
-        weight = weight.float()
-        bias = bias.float() if bias is not None else None
-    if upcast:
-        x = x.float()
-        residual = residual.float() if residual is not None else residual
-    if residual is not None:
-        x = (x + residual).to(x.dtype)
-    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
-        dtype
-    )
-    return out if not prenorm else (out, x)
-
-
-def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
-    dtype = x.dtype
-    if upcast:
-        weight = weight.float()
-        bias = bias.float() if bias is not None else None
-    if upcast:
-        x = x.float()
-        residual = residual.float() if residual is not None else residual
-    if residual is not None:
-        x = (x + residual).to(x.dtype)
-    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-    out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
-    out = out.to(dtype)
-    return out if not prenorm else (out, x)
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
-    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
-)
-# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
-# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
-@triton.jit
-def _layer_norm_fwd_1pass_kernel(
-    X,  # pointer to the input
-    Y,  # pointer to the output
-    W,  # pointer to the weights
-    B,  # pointer to the biases
-    RESIDUAL,  # pointer to the residual
-    RESIDUAL_OUT,  # pointer to the residual
-    Mean,  # pointer to the mean
-    Rstd,  # pointer to the 1/std
-    stride_x_row,  # how much to increase the pointer when moving by 1 row
-    stride_y_row,
-    stride_res_row,
-    stride_res_out_row,
-    N,  # number of columns in X
-    eps,  # epsilon to avoid division by zero
-    IS_RMS_NORM: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    HAS_RESIDUAL: tl.constexpr,
-    STORE_RESIDUAL_OUT: tl.constexpr,
-    HAS_BIAS: tl.constexpr,
-):
-    # Map the program id to the row of X and Y it should compute.
-    row = tl.program_id(0)
-    X += row * stride_x_row
-    Y += row * stride_y_row
-    if HAS_RESIDUAL:
-        RESIDUAL += row * stride_res_row
-    if STORE_RESIDUAL_OUT:
-        RESIDUAL_OUT += row * stride_res_out_row
-    # Compute mean and variance
-    cols = tl.arange(0, BLOCK_N)
-    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-    if HAS_RESIDUAL:
-        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
-        x += residual
-    if STORE_RESIDUAL_OUT:
-        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
-    if not IS_RMS_NORM:
-        mean = tl.sum(x, axis=0) / N
-        tl.store(Mean + row, mean)
-        xbar = tl.where(cols < N, x - mean, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    else:
-        xbar = tl.where(cols < N, x, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    rstd = 1 / tl.sqrt(var + eps)
-    tl.store(Rstd + row, rstd)
-    # Normalize and apply linear transformation
-    mask = cols < N
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
-    if HAS_BIAS:
-        b = tl.load(B + cols, mask=mask).to(tl.float32)
-    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-    y = x_hat * w + b if HAS_BIAS else x_hat * w
-    # Write output
-    tl.store(Y + cols, y, mask=mask)
-
-
-def _layer_norm_fwd(
-    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
-):
-    if residual is not None:
-        residual_dtype = residual.dtype
-    M, N = x.shape
-    assert x.stride(-1) == 1
-    if residual is not None:
-        assert residual.stride(-1) == 1
-        assert residual.shape == (M, N)
-    assert weight.shape == (N,)
-    assert weight.stride(-1) == 1
-    if bias is not None:
-        assert bias.stride(-1) == 1
-        assert bias.shape == (N,)
-    # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
-    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
-        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
-        assert residual_out.stride(-1) == 1
-    else:
-        residual_out = None
-    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
-    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
-    # Less than 64KB per feature: enqueue fused kernel
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
-    if N > BLOCK_N:
-        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
-    # heuristics for number of warps
-    with torch.cuda.device(x.device.index):
-        _layer_norm_fwd_1pass_kernel[(M,)](
-            x,
-            y,
-            weight,
-            bias,
-            residual,
-            residual_out,
-            mean,
-            rstd,
-            x.stride(0),
-            y.stride(0),
-            residual.stride(0) if residual is not None else 0,
-            residual_out.stride(0) if residual_out is not None else 0,
-            N,
-            eps,
-            is_rms_norm,
-            BLOCK_N,
-            residual is not None,
-            residual_out is not None,
-            bias is not None,
-        )
-    # residual_out is None if residual is None and residual_dtype == input_dtype
-    return y, mean, rstd, residual_out if residual_out is not None else x
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({}, num_warps=1),
-        triton.Config({}, num_warps=2),
-        triton.Config({}, num_warps=4),
-        triton.Config({}, num_warps=8),
-        triton.Config({}, num_warps=16),
-        triton.Config({}, num_warps=32),
-    ],
-    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
-)
-# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
-# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
-# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
-@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
-@triton.jit
-def _layer_norm_bwd_kernel(
-    X,  # pointer to the input
-    W,  # pointer to the weights
-    B,  # pointer to the biases
-    Y,  # pointer to the output to be recomputed
-    DY,  # pointer to the output gradient
-    DX,  # pointer to the input gradient
-    DW,  # pointer to the partial sum of weights gradient
-    DB,  # pointer to the partial sum of biases gradient
-    DRESIDUAL,
-    DRESIDUAL_IN,
-    Mean,  # pointer to the mean
-    Rstd,  # pointer to the 1/std
-    stride_x_row,  # how much to increase the pointer when moving by 1 row
-    stride_y_row,
-    stride_dy_row,
-    stride_dx_row,
-    stride_dres_row,
-    stride_dres_in_row,
-    M,  # number of rows in X
-    N,  # number of columns in X
-    eps,  # epsilon to avoid division by zero
-    rows_per_program,
-    IS_RMS_NORM: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    HAS_DRESIDUAL: tl.constexpr,
-    STORE_DRESIDUAL: tl.constexpr,
-    HAS_BIAS: tl.constexpr,
-    RECOMPUTE_OUTPUT: tl.constexpr,
-):
-    # Map the program id to the elements of X, DX, and DY it should compute.
-    row_block_id = tl.program_id(0)
-    row_start = row_block_id * rows_per_program
-    cols = tl.arange(0, BLOCK_N)
-    mask = cols < N
-    X += row_start * stride_x_row
-    if HAS_DRESIDUAL:
-        DRESIDUAL += row_start * stride_dres_row
-    if STORE_DRESIDUAL:
-        DRESIDUAL_IN += row_start * stride_dres_in_row
-    DY += row_start * stride_dy_row
-    DX += row_start * stride_dx_row
-    if RECOMPUTE_OUTPUT:
-        Y += row_start * stride_y_row
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
-    if RECOMPUTE_OUTPUT and HAS_BIAS:
-        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
-    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
-    if HAS_BIAS:
-        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
-    row_end = min((row_block_id + 1) * rows_per_program, M)
-    for row in range(row_start, row_end):
-        # Load data to SRAM
-        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
-        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
-        if not IS_RMS_NORM:
-            mean = tl.load(Mean + row)
-        rstd = tl.load(Rstd + row)
-        # Compute dx
-        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-        xhat = tl.where(mask, xhat, 0.0)
-        if RECOMPUTE_OUTPUT:
-            y = xhat * w + b if HAS_BIAS else xhat * w
-            tl.store(Y + cols, y, mask=mask)
-        wdy = w * dy
-        dw += dy * xhat
-        if HAS_BIAS:
-            db += dy
-        if not IS_RMS_NORM:
-            c1 = tl.sum(xhat * wdy, axis=0) / N
-            c2 = tl.sum(wdy, axis=0) / N
-            dx = (wdy - (xhat * c1 + c2)) * rstd
-        else:
-            c1 = tl.sum(xhat * wdy, axis=0) / N
-            dx = (wdy - xhat * c1) * rstd
-        if HAS_DRESIDUAL:
-            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
-            dx += dres
-        # Write dx
-        if STORE_DRESIDUAL:
-            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
-        tl.store(DX + cols, dx, mask=mask)
-
-        X += stride_x_row
-        if HAS_DRESIDUAL:
-            DRESIDUAL += stride_dres_row
-        if STORE_DRESIDUAL:
-            DRESIDUAL_IN += stride_dres_in_row
-        if RECOMPUTE_OUTPUT:
-            Y += stride_y_row
-        DY += stride_dy_row
-        DX += stride_dx_row
-    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
-    if HAS_BIAS:
-        tl.store(DB + row_block_id * N + cols, db, mask=mask)
-
-
-def _layer_norm_bwd(
-    dy,
-    x,
-    weight,
-    bias,
-    eps,
-    mean,
-    rstd,
-    dresidual=None,
-    has_residual=False,
-    is_rms_norm=False,
-    x_dtype=None,
-    recompute_output=False,
-):
-    M, N = x.shape
-    assert x.stride(-1) == 1
-    assert dy.stride(-1) == 1
-    assert dy.shape == (M, N)
-    if dresidual is not None:
-        assert dresidual.stride(-1) == 1
-        assert dresidual.shape == (M, N)
-    assert weight.shape == (N,)
-    assert weight.stride(-1) == 1
-    if bias is not None:
-        assert bias.stride(-1) == 1
-        assert bias.shape == (N,)
-    # allocate output
-    dx = (
-        torch.empty_like(x)
-        if x_dtype is None
-        else torch.empty(M, N, dtype=x_dtype, device=x.device)
-    )
-    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
-    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
-
-    # Less than 64KB per feature: enqueue fused kernel
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
-    if N > BLOCK_N:
-        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
-    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
-    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
-    _db = (
-        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
-        if bias is not None
-        else None
-    )
-    rows_per_program = math.ceil(M / sm_count)
-    grid = (sm_count,)
-    with torch.cuda.device(x.device.index):
-        _layer_norm_bwd_kernel[grid](
-            x,
-            weight,
-            bias,
-            y,
-            dy,
-            dx,
-            _dw,
-            _db,
-            dresidual,
-            dresidual_in,
-            mean,
-            rstd,
-            x.stride(0),
-            0 if not recompute_output else y.stride(0),
-            dy.stride(0),
-            dx.stride(0),
-            dresidual.stride(0) if dresidual is not None else 0,
-            dresidual_in.stride(0) if dresidual_in is not None else 0,
-            M,
-            N,
-            eps,
-            rows_per_program,
-            is_rms_norm,
-            BLOCK_N,
-            dresidual is not None,
-            dresidual_in is not None,
-            bias is not None,
-        )
-    dw = _dw.sum(0).to(weight.dtype)
-    db = _db.sum(0).to(bias.dtype) if bias is not None else None
-    # Don't need to compute dresidual_in separately in this case
-    if has_residual and dx.dtype == x.dtype:
-        dresidual_in = dx
-    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
-
-
-class LayerNormFn(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        x,
-        weight,
-        bias,
-        residual=None,
-        eps=1e-6,
-        prenorm=False,
-        residual_in_fp32=False,
-        is_rms_norm=False,
-    ):
-        x_shape_og = x.shape
-        # reshape input data into 2D tensor
-        x = x.reshape(-1, x.shape[-1])
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if residual is not None:
-            assert residual.shape == x_shape_og
-            residual = residual.reshape(-1, residual.shape[-1])
-            if residual.stride(-1) != 1:
-                residual = residual.contiguous()
-        weight = weight.contiguous()
-        if bias is not None:
-            bias = bias.contiguous()
-        residual_dtype = (
-            residual.dtype
-            if residual is not None
-            else (torch.float32 if residual_in_fp32 else None)
-        )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
-            x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
-        )
-        ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
-        ctx.x_shape_og = x_shape_og
-        ctx.eps = eps
-        ctx.is_rms_norm = is_rms_norm
-        ctx.has_residual = residual is not None
-        ctx.prenorm = prenorm
-        ctx.x_dtype = x.dtype
-        y = y.reshape(x_shape_og)
-        return y if not prenorm else (y, residual_out.reshape(x_shape_og))
-
-    @staticmethod
-    def backward(ctx, dy, *args):
-        x, weight, bias, mean, rstd = ctx.saved_tensors
-        dy = dy.reshape(-1, dy.shape[-1])
-        if dy.stride(-1) != 1:
-            dy = dy.contiguous()
-        assert dy.shape == x.shape
-        if ctx.prenorm:
-            dresidual = args[0]
-            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
-            if dresidual.stride(-1) != 1:
-                dresidual = dresidual.contiguous()
-            assert dresidual.shape == x.shape
-        else:
-            dresidual = None
-        dx, dw, db, dresidual_in = _layer_norm_bwd(
-            dy,
-            x,
-            weight,
-            bias,
-            ctx.eps,
-            mean,
-            rstd,
-            dresidual,
-            ctx.has_residual,
-            ctx.is_rms_norm,
-            x_dtype=ctx.x_dtype,
-        )
-        return (
-            dx.reshape(ctx.x_shape_og),
-            dw,
-            db,
-            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
-            None,
-            None,
-            None,
-            None,
-        )
-
-
-def layer_norm_fn(
-    x,
-    weight,
-    bias,
-    residual=None,
-    eps=1e-6,
-    prenorm=False,
-    residual_in_fp32=False,
-    is_rms_norm=False,
-):
-    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
-
-
-def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
-    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
-
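A minimal usage sketch (not from the original module; it assumes a CUDA device with Triton available, since the fused kernel has no CPU path) comparing the fused call against the pure-PyTorch reference above:

import torch

x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(512, device="cuda", dtype=torch.float16)

# prenorm=True returns both the normalized output and the updated residual stream
y_fused, res_out = rms_norm_fn(x, weight, None, residual=residual, prenorm=True)
y_ref, _ = rms_norm_ref(x, weight, None, residual=residual, prenorm=True, upcast=True)
print((y_fused.float() - y_ref.float()).abs().max())    # expected to be small (fp16 tolerance)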
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
-        self.register_parameter("bias", None)
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)
-
-    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
-        return rms_norm_fn(
-            x,
-            self.weight,
-            self.bias,
-            residual=residual,
-            eps=self.eps,
-            prenorm=prenorm,
-            residual_in_fp32=residual_in_fp32,
-            # is_rms_norm=True,
-        )
-
-
-class LayerNormLinearFn(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(
-        ctx,
-        x,
-        norm_weight,
-        norm_bias,
-        linear_weight,
-        linear_bias,
-        residual=None,
-        eps=1e-6,
-        prenorm=False,
-        residual_in_fp32=False,
-        is_rms_norm=False,
-    ):
-        x_shape_og = x.shape
-        # reshape input data into 2D tensor
-        x = x.reshape(-1, x.shape[-1])
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if residual is not None:
-            assert residual.shape == x_shape_og
-            residual = residual.reshape(-1, residual.shape[-1])
-            if residual.stride(-1) != 1:
-                residual = residual.contiguous()
-        norm_weight = norm_weight.contiguous()
-        if norm_bias is not None:
-            norm_bias = norm_bias.contiguous()
-        residual_dtype = (
-            residual.dtype
-            if residual is not None
-            else (torch.float32 if residual_in_fp32 else None)
-        )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
-            x,
-            norm_weight,
-            norm_bias,
-            eps,
-            residual,
-            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
-            residual_dtype=residual_dtype,
-            is_rms_norm=is_rms_norm,
-        )
-        y = y.reshape(x_shape_og)
-        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
-        linear_weight = linear_weight.to(dtype)
-        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
-        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
-        # We don't store y, will be recomputed in the backward pass to save memory
-        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
-        ctx.x_shape_og = x_shape_og
-        ctx.eps = eps
-        ctx.is_rms_norm = is_rms_norm
-        ctx.has_residual = residual is not None
-        ctx.prenorm = prenorm
-        ctx.x_dtype = x.dtype
-        ctx.linear_bias_is_none = linear_bias is None
-        return out if not prenorm else (out, residual_out.reshape(x_shape_og))
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, dout, *args):
-        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
-        dout = dout.reshape(-1, dout.shape[-1])
-        dy = F.linear(dout, linear_weight.t())
-        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
-        if dy.stride(-1) != 1:
-            dy = dy.contiguous()
-        assert dy.shape == x.shape
-        if ctx.prenorm:
-            dresidual = args[0]
-            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
-            if dresidual.stride(-1) != 1:
-                dresidual = dresidual.contiguous()
-            assert dresidual.shape == x.shape
-        else:
-            dresidual = None
-        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
-            dy,
-            x,
-            norm_weight,
-            norm_bias,
-            ctx.eps,
-            mean,
-            rstd,
-            dresidual,
-            ctx.has_residual,
-            ctx.is_rms_norm,
-            x_dtype=ctx.x_dtype,
-            recompute_output=True,
-        )
-        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
-        return (
-            dx.reshape(ctx.x_shape_og),
-            dnorm_weight,
-            dnorm_bias,
-            dlinear_weight,
-            dlinear_bias,
-            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
-            None,
-            None,
-            None,
-            None,
-        )
-
-
-def layer_norm_linear_fn(
-    x,
-    norm_weight,
-    norm_bias,
-    linear_weight,
-    linear_bias,
-    residual=None,
-    eps=1e-6,
-    prenorm=False,
-    residual_in_fp32=False,
-    is_rms_norm=False,
-):
-    return LayerNormLinearFn.apply(
-        x,
-        norm_weight,
-        norm_bias,
-        linear_weight,
-        linear_bias,
-        residual,
-        eps,
-        prenorm,
-        residual_in_fp32,
-        is_rms_norm,
-    )
diff --git a/mamba/mamba_ssm/ops/triton/selective_state_update.py b/mamba/mamba_ssm/ops/triton/selective_state_update.py
deleted file mode 100644
index fa95de73f173292914c5f00fbe9426937d00e502..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/ops/triton/selective_state_update.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# Copyright (c) 2023, Tri Dao.
-
-"""We want triton==2.1.0 for this
-"""
-
-import math
-import torch
-import torch.nn.functional as F
-
-import triton
-import triton.language as tl
-
-from einops import rearrange, repeat
-
-
-@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
-@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
-@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
-@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
-@triton.jit
-def _selective_scan_update_kernel(
-    # Pointers to matrices
-    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
-    # Matrix dimensions
-    batch, dim, dstate,
-    # Strides
-    stride_state_batch, stride_state_dim, stride_state_dstate,
-    stride_x_batch, stride_x_dim,
-    stride_dt_batch, stride_dt_dim,
-    stride_dt_bias_dim,
-    stride_A_dim, stride_A_dstate,
-    stride_B_batch, stride_B_dstate,
-    stride_C_batch, stride_C_dstate,
-    stride_D_dim,
-    stride_z_batch, stride_z_dim,
-    stride_out_batch, stride_out_dim,
-    # Meta-parameters
-    DT_SOFTPLUS: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr,
-    HAS_DT_BIAS: tl.constexpr,
-    HAS_D: tl.constexpr,
-    HAS_Z: tl.constexpr,
-    BLOCK_SIZE_DSTATE: tl.constexpr,
-):
-    pid_m = tl.program_id(axis=0)
-    pid_b = tl.program_id(axis=1)
-    state_ptr += pid_b * stride_state_batch
-    x_ptr += pid_b * stride_x_batch
-    dt_ptr += pid_b * stride_dt_batch
-    B_ptr += pid_b * stride_B_batch
-    C_ptr += pid_b * stride_C_batch
-    if HAS_Z:
-        z_ptr += pid_b * stride_z_batch
-    out_ptr += pid_b * stride_out_batch
-
-    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
-    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
-    x_ptrs = x_ptr + offs_m * stride_x_dim
-    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
-    if HAS_DT_BIAS:
-        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
-    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
-    B_ptrs = B_ptr + offs_n * stride_B_dstate
-    C_ptrs = C_ptr + offs_n * stride_C_dstate
-    if HAS_D:
-        D_ptrs = D_ptr + offs_m * stride_D_dim
-    if HAS_Z:
-        z_ptrs = z_ptr + offs_m * stride_z_dim
-    out_ptrs = out_ptr + offs_m * stride_out_dim
-
-    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
-    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if HAS_DT_BIAS:
-        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if DT_SOFTPLUS:
-        dt = tl.log(1.0 + tl.exp(dt))
-    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
-    dA = tl.exp(A * dt[:, None])
-    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
-    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
-    if HAS_D:
-        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if HAS_Z:
-        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-
-    dB = B[None, :] * dt[:, None]
-    state = state * dA + dB * x[:, None]
-    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
-    out = tl.sum(state * C[None, :], axis=1)
-    if HAS_D:
-        out += x * D
-    if HAS_Z:
-        out *= z * tl.sigmoid(z)
-    tl.store(out_ptrs, out, mask=offs_m < dim)
-
-
-def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
-    """
-    Argument:
-        state: (batch, dim, dstate)
-        x: (batch, dim)
-        dt: (batch, dim)
-        A: (dim, dstate)
-        B: (batch, dstate)
-        C: (batch, dstate)
-        D: (dim,)
-        z: (batch, dim)
-        dt_bias: (dim,)
-    Return:
-        out: (batch, dim)
-    """
-    batch, dim, dstate = state.shape
-    assert x.shape == (batch, dim)
-    assert dt.shape == x.shape
-    assert A.shape == (dim, dstate)
-    assert B.shape == (batch, dstate)
-    assert C.shape == B.shape
-    if D is not None:
-        assert D.shape == (dim,)
-    if z is not None:
-        assert z.shape == x.shape
-    if dt_bias is not None:
-        assert dt_bias.shape == (dim,)
-    out = torch.empty_like(x)
-    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
-    z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
-    # We don't want autotune since it will overwrite the state
-    # We instead tune by hand.
-    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
-                               else ((16, 4) if dstate <= 32 else
-                                     ((8, 4) if dstate <= 64 else
-                                      ((4, 4) if dstate <= 128 else
-                                       ((4, 8))))))
-    with torch.cuda.device(x.device.index):
-        _selective_scan_update_kernel[grid](
-            state, x, dt, dt_bias, A, B, C, D, z, out,
-            batch, dim, dstate,
-            state.stride(0), state.stride(1), state.stride(2),
-            x.stride(0), x.stride(1),
-            dt.stride(0), dt.stride(1),
-            dt_bias.stride(0) if dt_bias is not None else 0,
-            A.stride(0), A.stride(1),
-            B.stride(0), B.stride(1),
-            C.stride(0), C.stride(1),
-            D.stride(0) if D is not None else 0,
-            z_strides[0], z_strides[1],
-            out.stride(0), out.stride(1),
-            dt_softplus,
-            BLOCK_SIZE_M,
-            num_warps=num_warps,
-        )
-    return out
-
-
-def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
-    """
-    Argument:
-        state: (batch, dim, dstate)
-        x: (batch, dim)
-        dt: (batch, dim)
-        A: (dim, dstate)
-        B: (batch, dstate)
-        C: (batch, dstate)
-        D: (dim,)
-        z: (batch, dim)
-        dt_bias: (dim,)
-    Return:
-        out: (batch, dim)
-    """
-    batch, dim, dstate = state.shape
-    assert x.shape == (batch, dim)
-    assert dt.shape == x.shape
-    assert A.shape == (dim, dstate)
-    assert B.shape == (batch, dstate)
-    assert C.shape == B.shape
-    if D is not None:
-        assert D.shape == (dim,)
-    if z is not None:
-        assert z.shape == x.shape
-    if dt_bias is not None:
-        assert dt_bias.shape == (dim,)
-        dt = dt + dt_bias
-    dt = F.softplus(dt) if dt_softplus else dt
-    dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)  # (batch, dim, dstate)
-    dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n")  # (batch, dim, dstate)
-    state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate)
-    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
-    if D is not None:
-        out += (x * D).to(out.dtype)
-    return (out if z is None else out * F.silu(z)).to(x.dtype)
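For reference, a minimal sketch of one decoding-step call to this op, shapes as in the docstring above; it assumes the pre-built mamba_ssm wheel installed by the Space still exposes this module path.

import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

batch, dim, dstate = 2, 1536, 16
device = "cuda"
state = torch.zeros(batch, dim, dstate, device=device)   # running SSM state, updated in place
x = torch.randn(batch, dim, device=device)                # input for the current token
dt = torch.randn(batch, dim, device=device)               # per-channel step size
A = -torch.rand(dim, dstate, device=device)               # negative A keeps exp(A * dt) <= 1
B = torch.randn(batch, dstate, device=device)
C = torch.randn(batch, dstate, device=device)
out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)  # (batch, dim)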
diff --git a/mamba/mamba_ssm/utils/__init__.py b/mamba/mamba_ssm/utils/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/mamba/mamba_ssm/utils/generation.py b/mamba/mamba_ssm/utils/generation.py
deleted file mode 100644
index 9d766b29ac28a388a7d77b22aa2cb1eda733c0f4..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/utils/generation.py
+++ /dev/null
@@ -1,377 +0,0 @@
-# Copyright (c) 2023, Albert Gu, Tri Dao.
-import gc
-import time
-from collections import namedtuple
-from dataclasses import dataclass, field
-from functools import partial
-from typing import Callable, Optional, Sequence, Union
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange, repeat
-from torch import Tensor
-from torch.profiler import ProfilerActivity, profile, record_function
-from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
-
-
-@dataclass
-class InferenceParams:
-    """Inference parameters that are passed to the main model in order
-    to efficiently calculate and store the context during inference."""
-
-    max_seqlen: int
-    max_batch_size: int
-    seqlen_offset: int = 0
-    batch_size_offset: int = 0
-    key_value_memory_dict: dict = field(default_factory=dict)
-    lengths_per_sample: Optional[Tensor] = None
-
-    def reset(self, max_seqlen, max_batch_size):
-        self.max_seqlen = max_seqlen
-        self.max_batch_size = max_batch_size
-        self.seqlen_offset = 0
-        if self.lengths_per_sample is not None:
-            self.lengths_per_sample.zero_()
-
-
-# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
-# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
-def modify_logits_for_top_k_filtering(logits, top_k):
-    """Set the logits for none top-k values to -inf. Done in-place."""
-    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    logits.masked_fill_(indices_to_remove, float("-Inf"))
-
-
-# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
-# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
-def modify_logits_for_top_p_filtering(logits, top_p):
-    """Set the logits for none top-p values to -inf. Done in-place."""
-    if top_p <= 0.0 or top_p >= 1.0:
-        return
-    # First sort and calculate cumulative sum of probabilities.
-    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
-    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-    # Remove tokens whose cumulative probability (in ascending order) is at most 1 - top_p, keeping mass >= top_p
-    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
-    # scatter sorted tensors to original indexing
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove
-    )
-    logits.masked_fill_(indices_to_remove, float("-inf"))
-
-
-def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
-    """Sample from top-k logits.
-    Arguments:
-        logits: Tensor of shape (batch_size, vocab_size)
-    """
-    if top_k == 1:  # Short-circuit for greedy decoding
-        return logits.argmax(dim=-1)
-    else:
-        if top_p > 0.0:
-            assert top_p <= 1.0, "top-p should be in (0, 1]."
-        if top_k > 0:
-            top_k = min(top_k, logits.size(-1))  # Safety check
-            logits_top, indices = torch.topk(logits, top_k, dim=-1)
-            if temperature != 1.0:
-                logits_top /= temperature
-            modify_logits_for_top_p_filtering(logits_top, top_p)
-            return indices[
-                torch.arange(indices.shape[0], device=indices.device),
-                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
-            ]
-        else:
-            # Clone so that when we modify for top_p we don't change the original logits
-            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
-            modify_logits_for_top_p_filtering(logits_top, top_p)
-            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
-                dim=-1
-            )
-
-
-@torch.inference_mode()
-def decode(
-    input_ids,
-    model,
-    max_length,
-    top_k=1,
-    top_p=0.0,
-    temperature=1.0,
-    eos_token_id=None,
-    teacher_outputs=None,
-    vocab_size=None,
-    tensor_parallel=1,
-    cg=False,
-    enable_timing=False,
-):
-    """Decoding, either greedy or with top-k or top-p sampling.
-    If top-k = 0, don't limit the number of candidates (pure sampling).
-    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
-    then top-p.
-    We assume that all sequences in the same batch have the same length.
-
-    Arguments:
-        input_ids: (batch, seq_len)
-        max_length: int
-        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
-            logits, the next token is taken from the teacher_outputs. Useful for testing.
-    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
-        sequences: (batch, max_length)
-        scores: tuple of tensors of shape (batch, vocab_size)
-    """
-    batch_size, seqlen_og = input_ids.shape
-    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
-    if cg:
-        if not hasattr(model, "_decoding_cache"):
-            model._decoding_cache = None
-        model._decoding_cache = update_graph_cache(
-            model,
-            model._decoding_cache,
-            batch_size,
-            seqlen_og,
-            max_length,
-            tensor_parallel=tensor_parallel,
-        )
-        inference_params = model._decoding_cache.inference_params
-        inference_params.reset(max_length, batch_size)
-    else:
-        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
-
-    def get_logits(input_ids, inference_params):
-        decoding = inference_params.seqlen_offset > 0
-        if decoding:
-            position_ids = torch.full(
-                (batch_size, 1),
-                inference_params.seqlen_offset,
-                dtype=torch.long,
-                device=input_ids.device,
-            )
-        else:
-            position_ids = None
-        if not cg or not decoding:
-            logits = model(
-                input_ids,
-                position_ids=position_ids,
-                inference_params=inference_params,
-                num_last_tokens=1,
-            ).logits.squeeze(dim=1)
-        else:
-            logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.seqlen_offset
-            ).squeeze(dim=1)
-        return logits[..., :vocab_size] if vocab_size is not None else logits
-
-    def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
-            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        else:
-            token = teacher_outputs[:, inference_params.seqlen_offset]
-        # return rearrange(token, "b -> b 1")
-        return token.unsqueeze(1)
-
-    def should_stop(current_token, inference_params):
-        if inference_params.seqlen_offset == 0:
-            return False
-        if eos_token_id is not None and (current_token == eos_token_id).all():
-            return True
-        if inference_params.seqlen_offset >= max_length - 1:
-            return True
-        return False
-
-    start = torch.cuda.Event(enable_timing=enable_timing)
-    end = torch.cuda.Event(enable_timing=enable_timing)
-
-    if enable_timing:
-        if tensor_parallel > 1:
-            torch.distributed.barrier()
-        start.record()
-    scores, sequences = [], [input_ids]
-    while not should_stop(sequences[-1], inference_params):
-        scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.seqlen_offset += sequences[-1].shape[1]
-        sequences.append(sample_tokens(scores[-1], inference_params))
-    if enable_timing:
-        end.record()
-        if tensor_parallel > 1:
-            torch.distributed.barrier()
-        torch.cuda.synchronize()
-        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
-    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
-    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
-
-
-class GenerationMixin:
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        raise NotImplementedError
-
-    def generate(
-        self,
-        input_ids,
-        max_length,
-        top_k=1,
-        top_p=0.0,
-        temperature=1.0,
-        return_dict_in_generate=False,
-        output_scores=False,
-        **kwargs,
-    ):
-        output = decode(
-            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
-        )
-        if not output_scores:
-            output.scores = None
-        return output if return_dict_in_generate else output.sequences
-
-
-def allocate_inference_cache(
-    max_batch_size,
-    max_seqlen,
-    nheads,
-    headdim,
-    layers: Union[int, Sequence],
-    device,
-    dtype=torch.float16,
-):
-    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
-    if isinstance(layers, int):
-        layers = range(layers)
-    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
-
-
-@dataclass
-class DecodingCGCache:
-    max_batch_size: int = 0
-    max_seqlen: int = 0
-    device = None
-    dtype = None
-    callables: dict = field(default_factory=dict)
-    mempool = None
-    inference_params: Optional[InferenceParams] = None
-    run: Optional[Callable] = None
-
-
-@torch.inference_mode()
-def update_graph_cache(
-    model,
-    cache,
-    batch_size,
-    seqlen_og,
-    max_seqlen,
-    decoding_seqlens=(1,),
-    tensor_parallel=1,
-    dtype=None,
-    n_warmups=2,
-):
-    if cache is None:
-        cache = DecodingCGCache()
-    param_example = next(iter(model.parameters()))
-    device = param_example.device
-    if dtype is None:
-        dtype = param_example.dtype
-    if (
-        (device, dtype) != (cache.device, cache.dtype)
-        or batch_size > cache.max_batch_size
-        or max_seqlen > cache.max_seqlen
-    ):  # Invalidate the cache
-        cache.callables = {}
-        cache.mempool = None
-        cache.inference_params = None
-        gc.collect()
-        cache.device, cache.dtype = device, dtype
-        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
-        if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
-        else:
-            headdim = getattr(
-                model.config,
-                "head_dim",
-                model.config.hidden_size // model.config.num_attention_heads,
-            )
-            inf_cache = allocate_inference_cache(
-                batch_size,
-                max_seqlen,
-                model.config.num_attention_heads // tensor_parallel,
-                headdim,
-                model.config.num_hidden_layers,
-                device,
-                dtype,
-            )
-        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
-        cache.inference_params = InferenceParams(
-            max_seqlen=max_seqlen,
-            max_batch_size=batch_size,
-            seqlen_offset=seqlen_og,
-            key_value_memory_dict=inf_cache,
-            lengths_per_sample=lengths_per_sample,
-        )
-        cache.mempool = torch.cuda.graphs.graph_pool_handle()
-    for decoding_seqlen in decoding_seqlens:
-        if (batch_size, decoding_seqlen) not in cache.callables:
-            cache.callables[batch_size, decoding_seqlen] = capture_graph(
-                model,
-                cache.inference_params,
-                batch_size,
-                max_seqlen,
-                decoding_seqlen=decoding_seqlen,
-                mempool=cache.mempool,
-                n_warmups=n_warmups,
-            )
-
-    def dispatch(input_ids, position_ids, seqlen):
-        batch_size, decoding_seqlen = input_ids.shape[:2]
-        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
-
-    cache.run = dispatch
-    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
-    return cache
-
-
-def capture_graph(
-    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
-):
-    device = next(iter(model.parameters())).device
-    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
-    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
-    seqlen_offset_og = inference_params.seqlen_offset
-    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
-    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
-
-    # Warmup before capture
-    s = torch.cuda.Stream()
-    s.wait_stream(torch.cuda.current_stream())
-    with torch.cuda.stream(s):
-        for _ in range(n_warmups):
-            logits = model(
-                input_ids,
-                position_ids=position_ids,
-                inference_params=inference_params,
-                num_last_tokens=decoding_seqlen,
-            ).logits
-        s.synchronize()
-        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
-        # which requires that graph launch and non-captured launch to not overlap (I think,
-        # that's how I interpret the documentation). I'm not sure if this is required.
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()
-    torch.cuda.current_stream().wait_stream(s)
-    # Captures the graph
-    # To allow capture, automatically sets a side stream as the current stream in the context
-    graph = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(graph, pool=mempool):
-        logits = model(
-            input_ids,
-            position_ids=position_ids,
-            inference_params=inference_params,
-            num_last_tokens=decoding_seqlen,
-        ).logits
-
-    def run(new_input_ids, new_position_ids, seqlen):
-        inference_params.lengths_per_sample[:] = seqlen
-        input_ids.copy_(new_input_ids)
-        position_ids.copy_(new_position_ids)
-        graph.replay()
-        return logits.clone()
-
-    inference_params.seqlen_offset = seqlen_offset_og
-    return run
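The top-k/top-p filtering used by sample() above can be summarized in a few self-contained lines; this is an illustrative re-implementation, not the library API, and filter_logits is a hypothetical helper name.

import torch

def filter_logits(logits, top_k=0, top_p=0.0):
    # Keep only the top-k logits, then drop the low-probability tail beyond top_p mass.
    logits = logits.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_value] = float("-inf")
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=False)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum_probs <= (1 - top_p)
        logits.masked_fill_(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    return logits

logits = torch.randn(2, 50257)  # (batch, vocab_size)
probs = torch.softmax(filter_logits(logits, top_k=50, top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (batch,)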
diff --git a/mamba/mamba_ssm/utils/hf.py b/mamba/mamba_ssm/utils/hf.py
deleted file mode 100644
index 0d7555acddbd260636d1d14d5bd6324f6af0056a..0000000000000000000000000000000000000000
--- a/mamba/mamba_ssm/utils/hf.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import json
-
-import torch
-
-from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
-from transformers.utils.hub import cached_file
-
-
-def load_config_hf(model_name):
-    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
-    return json.load(open(resolved_archive_file))
-
-
-def load_state_dict_hf(model_name, device=None, dtype=None):
-    # If not fp32, then we don't want to load directly to the GPU
-    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
-    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
-    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
-    # Convert dtype before moving to GPU to save memory
-    if dtype is not None:
-        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
-    if device is not None:
-        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
-    return state_dict
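A short usage sketch for these helpers; it assumes the installed mamba_ssm wheel ships the same utilities, and the checkpoint name below is only an example.

from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

config = load_config_hf("state-spaces/mamba-130m")          # plain dict parsed from config.json
state_dict = load_state_dict_hf("state-spaces/mamba-130m")  # tensors, kept on CPU unless device is given
print(config.get("d_model"), len(state_dict))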
diff --git a/mamba/setup.py b/mamba/setup.py
deleted file mode 100644
index 2ce0ac045f8b2ae07f39f3d045e997ab362ec4c1..0000000000000000000000000000000000000000
--- a/mamba/setup.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# Copyright (c) 2023, Albert Gu, Tri Dao.
-import sys
-import warnings
-import os
-import re
-import ast
-from pathlib import Path
-from packaging.version import parse, Version
-import platform
-import shutil
-
-from setuptools import setup, find_packages
-import subprocess
-
-import urllib.request
-import urllib.error
-from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
-
-import torch
-from torch.utils.cpp_extension import (
-    BuildExtension,
-    CppExtension,
-    CUDAExtension,
-    CUDA_HOME,
-)
-
-
-with open("README.md", "r", encoding="utf-8") as fh:
-    long_description = fh.read()
-
-
-# ninja build does not work unless include_dirs are absolute paths
-this_dir = os.path.dirname(os.path.abspath(__file__))
-
-PACKAGE_NAME = "mamba_ssm"
-
-BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"
-
-# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
-# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
-FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
-SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
-# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
-FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
-
-
-def get_platform():
-    """
-    Returns the platform name as used in wheel filenames.
-    """
-    if sys.platform.startswith("linux"):
-        return "linux_x86_64"
-    elif sys.platform == "darwin":
-        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
-        return f"macosx_{mac_version}_x86_64"
-    elif sys.platform == "win32":
-        return "win_amd64"
-    else:
-        raise ValueError("Unsupported platform: {}".format(sys.platform))
-
-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output(
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
-    )
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    bare_metal_version = parse(output[release_idx].split(",")[0])
-
-    return raw_output, bare_metal_version
-
-
-def check_if_cuda_home_none(global_option: str) -> None:
-    if CUDA_HOME is not None:
-        return
-    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
-    # in that case.
-    warnings.warn(
-        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
-        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
-        "only images whose names contain 'devel' will provide nvcc."
-    )
-
-
-def append_nvcc_threads(nvcc_extra_args):
-    return nvcc_extra_args + ["--threads", "4"]
-
-
-cmdclass = {}
-ext_modules = []
-
-if not SKIP_CUDA_BUILD:
-    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-
-    check_if_cuda_home_none(PACKAGE_NAME)
-    # Check if CUDA 11 is installed for compute capability 8.0
-    cc_flag = []
-    if CUDA_HOME is not None:
-        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-        if bare_metal_version < Version("11.6"):
-            raise RuntimeError(
-                f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above.  "
-                "Note: make sure nvcc has a supported version by running nvcc -V."
-            )
-
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_70,code=sm_70")
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_80,code=sm_80")
-    if CUDA_HOME is not None and bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
-
-    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
-    # torch._C._GLIBCXX_USE_CXX11_ABI
-    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
-    if FORCE_CXX11_ABI:
-        torch._C._GLIBCXX_USE_CXX11_ABI = True
-
-    ext_modules.append(
-        CUDAExtension(
-            name="selective_scan_cuda",
-            sources=[
-                "csrc/selective_scan/selective_scan.cpp",
-                "csrc/selective_scan/selective_scan_fwd_fp32.cu",
-                "csrc/selective_scan/selective_scan_fwd_fp16.cu",
-                "csrc/selective_scan/selective_scan_fwd_bf16.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
-                "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
-                "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
-            ],
-            extra_compile_args={
-                "cxx": ["-O3", "-std=c++17"],
-                "nvcc": append_nvcc_threads(
-                    [
-                        "-O3",
-                        "-std=c++17",
-                        "-U__CUDA_NO_HALF_OPERATORS__",
-                        "-U__CUDA_NO_HALF_CONVERSIONS__",
-                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                        "--expt-relaxed-constexpr",
-                        "--expt-extended-lambda",
-                        "--use_fast_math",
-                        "--ptxas-options=-v",
-                        "-lineinfo",
-                    ]
-                    + cc_flag
-                ),
-            },
-            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
-        )
-    )
-
-
-def get_package_version():
-    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
-        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
-    public_version = ast.literal_eval(version_match.group(1))
-    local_version = os.environ.get("MAMBA_LOCAL_VERSION")
-    if local_version:
-        return f"{public_version}+{local_version}"
-    else:
-        return str(public_version)
-
-
-def get_wheel_url():
-    # Determine the version numbers used to select the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
-    torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    mamba_ssm_version = get_package_version()
-    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    wheel_url = BASE_WHEEL_URL.format(
-        tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename
-    )
-    return wheel_url, wheel_filename
-
-
-class CachedWheelsCommand(_bdist_wheel):
-    """
-    The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it
-    cannot find an existing wheel (which is currently the case for all installs). We use the
-    environment parameters to detect whether a pre-built compatible wheel is already available and,
-    if so, short-circuit the standard full build pipeline.
-    """
-
-    def run(self):
-        if FORCE_BUILD:
-            return super().run()
-
-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            shutil.move(wheel_filename, wheel_path)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()
-
-
-setup(
-    name=PACKAGE_NAME,
-    version=get_package_version(),
-    packages=find_packages(
-        exclude=(
-            "build",
-            "csrc",
-            "include",
-            "tests",
-            "dist",
-            "docs",
-            "benchmarks",
-            "mamba_ssm.egg-info",
-        )
-    ),
-    author="Tri Dao, Albert Gu",
-    author_email="tri@tridao.me, agu@cs.cmu.edu",
-    description="Mamba state-space model",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/state-spaces/mamba",
-    classifiers=[
-        "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: BSD License",
-        "Operating System :: Unix",
-    ],
-    ext_modules=ext_modules,
-    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
-    if ext_modules
-    else {
-        "bdist_wheel": CachedWheelsCommand,
-    },
-    python_requires=">=3.7",
-    install_requires=[
-        "torch",
-        "packaging",
-        "ninja",
-        "einops",
-        "triton",
-        "transformers",
-        "causal_conv1d",
-    ],
-)
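For orientation, the wheel-name convention built by get_wheel_url() above expands to something like the following; the values are illustrative, chosen to match the cp310/cu118 wheel this diff checks in.

version, cuda, torch_mm, abi = "1.0.1", "118", "2.1", "FALSE"
py, plat = "cp310", "linux_x86_64"
wheel_name = f"mamba_ssm-{version}+cu{cuda}torch{torch_mm}cxx11abi{abi}-{py}-{py}-{plat}.whl"
print(wheel_name)  # mamba_ssm-1.0.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl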
diff --git a/mamba/test_mamba_module.py b/mamba/test_mamba_module.py
deleted file mode 100644
index 64710e92f7ec4fc0fe88821550e4ecf902a22bfe..0000000000000000000000000000000000000000
--- a/mamba/test_mamba_module.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch
-from mamba_ssm import Mamba
-
-batch, length, dim = 2, 64, 768
-x = torch.randn(batch, length, dim).to("cuda")
-model = Mamba(
-    # This module uses roughly 3 * expand * d_model^2 parameters
-    d_model=dim, # Model dimension d_model
-    d_state=16,  # SSM state expansion factor # 64
-    d_conv=4,    # Local convolution width
-    expand=2,    # Block expansion factor
-    use_fast_path=False,
-).to("cuda")
-y = model(x)
-assert y.shape == x.shape
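A quick sanity check of the parameter-count comment in this script; a sketch assuming the installed mamba_ssm wheel exposes Mamba with the same constructor (no forward pass, so no GPU is needed).

from mamba_ssm import Mamba

dim, expand = 768, 2
block = Mamba(d_model=dim, d_state=16, d_conv=4, expand=expand)
n_params = sum(p.numel() for p in block.parameters())
print(n_params, 3 * expand * dim ** 2)  # ~3.8M actual vs ~3.5M from the rule of thumb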
diff --git a/mamba/tests/ops/test_selective_scan.py b/mamba/tests/ops/test_selective_scan.py
deleted file mode 100644
index 26b34a37560f08ced653a1d9320a14f3d3f9ebd3..0000000000000000000000000000000000000000
--- a/mamba/tests/ops/test_selective_scan.py
+++ /dev/null
@@ -1,423 +0,0 @@
-# Copyright (C) 2023, Tri Dao.
-
-import math
-
-import torch
-import torch.nn.functional as F
-from torch.autograd import gradcheck
-import pytest
-
-from einops import rearrange
-
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
-from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
-from mamba_ssm.ops.selective_scan_interface import bimamba_inner_fn, bimamba_inner_ref
-
-
-# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
-@pytest.mark.parametrize('wtype', [torch.float32])
-# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('itype', [torch.float32])
-# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
-@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
-# @pytest.mark.parametrize('seqlen', [128])
-# @pytest.mark.parametrize("return_last_state", [False, True])
-@pytest.mark.parametrize("return_last_state", [True])
-# @pytest.mark.parametrize('has_delta_bias', [False, True])
-@pytest.mark.parametrize('has_delta_bias', [True])
-# @pytest.mark.parametrize('delta_softplus', [False, True])
-@pytest.mark.parametrize('delta_softplus', [True])
-# @pytest.mark.parametrize('has_z', [False, True])
-@pytest.mark.parametrize('has_z', [True])
-# @pytest.mark.parametrize('has_D', [False, True])
-@pytest.mark.parametrize('has_D', [True])
-@pytest.mark.parametrize("varBC_groups", [1, 2])
-# @pytest.mark.parametrize("varBC_groups", [1])
-# @pytest.mark.parametrize("is_variable_C", [False, True])
-@pytest.mark.parametrize("is_variable_C", [True])
-# @pytest.mark.parametrize("is_variable_B", [False, True])
-@pytest.mark.parametrize("is_variable_B", [True])
-def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
-                        delta_softplus, return_last_state, seqlen, itype, wtype):
-    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
-        pytest.skip()  # This config is not applicable
-    device = 'cuda'
-    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
-    if itype == torch.bfloat16:
-        rtol, atol = 3e-2, 5e-2
-    rtolw, atolw = (1e-3, 1e-3)
-    if has_z:  # If we have z, the errors on the weights seem higher
-        rtolw = max(rtolw, rtol)
-        atolw = max(atolw, atol)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    dim = 4
-    dstate = 8
-    is_complex = wtype == torch.complex64
-    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
-    if not is_variable_B:
-        B_shape = (dim, dstate)
-    elif varBC_groups == 1:
-        B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
-    else:
-        B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
-    B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
-                    requires_grad=True)
-    if not is_variable_C:
-        C_shape = (dim, dstate)
-    elif varBC_groups == 1:
-        C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
-    else:
-        C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
-    C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
-                    requires_grad=True)
-    if has_D:
-        D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    else:
-        D = None
-    if has_z:
-        z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
-    else:
-        z = None
-    if has_delta_bias:
-        delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
-    else:
-        delta_bias = None
-    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
-    delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
-    A_ref = A.detach().clone().requires_grad_()
-    B_ref = B.detach().clone().requires_grad_()
-    C_ref = C.detach().clone().requires_grad_()
-    D_ref = D.detach().clone().requires_grad_() if D is not None else None
-    z_ref = z.detach().clone().requires_grad_() if z is not None else None
-    u_ref = u.detach().clone().requires_grad_()
-    delta_ref = delta.detach().clone().requires_grad_()
-    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
-    out, *rest = selective_scan_fn(
-        u, delta, A, B, C, D, z=z,
-        delta_bias=delta_bias, delta_softplus=delta_softplus,
-        return_last_state=return_last_state
-    )
-    if return_last_state:
-        state = rest[0]
-    out_ref, *rest = selective_scan_ref(
-        u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
-        delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
-        return_last_state=return_last_state
-    )
-    if return_last_state:
-        state_ref = rest[0]
-    # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
-    # dt_u = delta * u
-
-    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
-    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-    if return_last_state:
-        print(f'State max diff: {(state - state_ref).abs().max().item()}')
-        assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
-
-    g = torch.randn_like(out)
-    out_ref.backward(g)
-    out.backward(g)
-
-    print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
-    print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
-    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
-    print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
-    print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
-    if has_D:
-        print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
-    if has_z:
-        print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
-    if has_delta_bias:
-        print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
-
-    assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
-    assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
-    assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
-    assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
-                          atol=atolw if not is_variable_B else atol)
-    assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
-                          atol=atolw if not is_variable_C else atol)
-    if has_D:
-        assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
-    if has_z:
-        assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
-    if has_delta_bias:
-        assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
-
-
-@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
-# @pytest.mark.parametrize('wtype', [torch.complex64])
-# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('itype', [torch.float32])
-# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
-@pytest.mark.parametrize('seqlen', [128])
-@pytest.mark.parametrize("is_variable_C", [False, True])
-# @pytest.mark.parametrize("is_variable_C", [False])
-@pytest.mark.parametrize("is_variable_B", [False, True])
-# @pytest.mark.parametrize("is_variable_B", [True])
-def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
-    device = 'cuda'
-    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
-    if itype == torch.bfloat16:
-        rtol, atol = 3e-2, 5e-2
-    rtolw, atolw = (1e-3, 1e-3)
-    # If we have z, the errors on the weights seem higher
-    rtolw = max(rtolw, rtol)
-    atolw = max(atolw, atol)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    dim = 768
-    dstate = 8
-    dt_rank = 48
-    is_complex = wtype == torch.complex64
-    xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
-    conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
-    conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
-                                * (1 if not is_complex else 2),
-                                dim, device=device, dtype=itype, requires_grad=True)
-    delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
-    out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
-    out_proj_bias = None
-    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
-    B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
-         if not is_variable_B else None)
-    C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
-         if not is_variable_C else None)
-    D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
-    B_proj_bias = None
-    C_proj_bias = None
-    xz_ref = xz.detach().clone().requires_grad_()
-    conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
-    conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
-    x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
-    delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
-    out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
-    out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
-                         if out_proj_bias is not None else None)
-    A_ref = A.detach().clone().requires_grad_()
-    B_ref = B.detach().clone().requires_grad_() if B is not None else None
-    C_ref = C.detach().clone().requires_grad_() if C is not None else None
-    D_ref = D.detach().clone().requires_grad_()
-    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
-    out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                         out_proj_weight, out_proj_bias,
-                         A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
-    out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
-                              delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
-                              A_ref, B_ref, C_ref, D_ref,
-                              delta_bias=delta_bias_ref, delta_softplus=True)
-    # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
-    # dt_u = delta * u
-    print("mamba_inner_fn")
-    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
-    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-
-    g = torch.randn_like(out)
-    out_ref.backward(g)
-    out.backward(g)
-
-    print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
-    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
-    if not is_variable_B:
-        print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
-    if not is_variable_C:
-        print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
-    print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
-    print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
-    print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
-    print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
-    print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
-    print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
-    print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
-
-    # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
-    # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
-    # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
-    # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
-    #                       atol=atolw if not is_variable_B else atol)
-    # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
-    #                       atol=atolw if not is_variable_C else atol)
-    # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
-    # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
-
-
-# test_mamba_inner_fn(False, False, 128, torch.float32, torch.float32)
-
-
-@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
-# @pytest.mark.parametrize('wtype', [torch.complex64])
-# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('itype', [torch.float32])
-# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
-@pytest.mark.parametrize('seqlen', [128])
-@pytest.mark.parametrize("is_variable_C", [False, True])
-# @pytest.mark.parametrize("is_variable_C", [False])
-@pytest.mark.parametrize("is_variable_B", [False, True])
-# @pytest.mark.parametrize("is_variable_B", [True])
-def test_bimamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
-    device = 'cuda'
-    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
-    if itype == torch.bfloat16:
-        rtol, atol = 3e-2, 5e-2
-    rtolw, atolw = (1e-3, 1e-3)
-    # If we have z, the errors on the weights seem higher
-    rtolw = max(rtolw, rtol)
-    atolw = max(atolw, atol)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    dim = 768
-    dstate = 8
-    dt_rank = 48
-    is_complex = wtype == torch.complex64
-    xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
-    conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
-    conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
-                                * (1 if not is_complex else 2),
-                                dim, device=device, dtype=itype, requires_grad=True)
-    delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
-    out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
-    out_proj_bias = None
-    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
-    A_b = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
-    B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
-         if not is_variable_B else None)
-    C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
-         if not is_variable_C else None)
-    D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
-    B_proj_bias = None
-    C_proj_bias = None
-    xz_ref = xz.detach().clone().requires_grad_()
-    conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
-    conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
-    x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
-    delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
-    out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
-    out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
-                         if out_proj_bias is not None else None)
-    A_ref = A.detach().clone().requires_grad_()
-    A_b_ref = A_b.detach().clone().requires_grad_()
-    B_ref = B.detach().clone().requires_grad_() if B is not None else None
-    C_ref = C.detach().clone().requires_grad_() if C is not None else None
-    D_ref = D.detach().clone().requires_grad_()
-    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
-    out = bimamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
-                         out_proj_weight, out_proj_bias,
-                         A, A_b, B, C, D, delta_bias=delta_bias, delta_softplus=True)
-    out_ref = bimamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
-                              delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
-                              A_ref, A_b_ref, B_ref, C_ref, D_ref,
-                              delta_bias=delta_bias_ref, delta_softplus=True)
-    # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
-    # dt_u = delta * u
-    print("bimamba_inner_fn")
-    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
-    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-
-    g = torch.randn_like(out)
-    out_ref.backward(g)
-    out.backward(g)
-
-    print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
-    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
-    print(f'dA_b max diff: {(A_b.grad - A_b_ref.grad).abs().max().item()}')
-    if not is_variable_B:
-        print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
-    if not is_variable_C:
-        print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
-    print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
-    print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
-    print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
-    print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
-    print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
-    print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
-    print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
-
-@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
-# @pytest.mark.parametrize('wtype', [torch.complex64])
-# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('itype', [torch.float32])
-# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
-@pytest.mark.parametrize('seqlen', [128])
-@pytest.mark.parametrize("is_variable_C", [False, True])
-# @pytest.mark.parametrize("is_variable_C", [False])
-@pytest.mark.parametrize("is_variable_B", [False, True])
-# @pytest.mark.parametrize("is_variable_B", [True])
-def test_bimamba_inner_fn_grad_check(is_variable_B, is_variable_C, seqlen, itype, wtype):
-    device = 'cuda'
-    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
-    if itype == torch.bfloat16:
-        rtol, atol = 3e-2, 5e-2
-    rtolw, atolw = (1e-3, 1e-3)
-    # If we have z, the errors on the weights seem higher
-    rtolw = max(rtolw, rtol)
-    atolw = max(atolw, atol)
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2 // 2
-    dim = 768 // 8
-    dstate = 8 // 8
-    dt_rank = 48 // 8
-    is_complex = wtype == torch.complex64
-    xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
-    conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
-    conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
-                                * (1 if not is_complex else 2),
-                                dim, device=device, dtype=itype, requires_grad=True)
-    delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
-    out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
-    out_proj_bias = None
-    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
-    A_b = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
-    B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
-         if not is_variable_B else None)
-    C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
-         if not is_variable_C else None)
-    D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
-    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
-    B_proj_bias = None
-    C_proj_bias = None
-    xz_ref = xz.detach().clone().requires_grad_()
-    conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
-    conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
-    x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
-    delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
-    out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
-    out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
-                         if out_proj_bias is not None else None)
-    A_ref = A.detach().clone().requires_grad_()
-    A_b_ref = A_b.detach().clone().requires_grad_()
-    B_ref = B.detach().clone().requires_grad_() if B is not None else None
-    C_ref = C.detach().clone().requires_grad_() if C is not None else None
-    D_ref = D.detach().clone().requires_grad_()
-    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
-
-    # func = bimamba_inner_fn
-    # func = mamba_inner_fn
-    func = mamba_inner_ref
-
-    # gradok = gradcheck(func, (xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias, A, A_b, B, C, D, delta_bias, None, None, True))
-    gradok = gradcheck(func, (xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias, A, B, C, D, delta_bias, None, None, True), eps=1e-6, atol=1e-4, nondet_tol=1.)
-    print(f'* {gradok} check_gradient_numerical {func.__name__}')
-
-
-
-# test_bimamba_inner_fn(True, True, 128, torch.float32, torch.float32)
-# test_mamba_inner_fn(True, True, 128, torch.float32, torch.float32)
-test_bimamba_inner_fn_grad_check(True, True, 128, torch.float32, torch.float32)
-
-# input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
-# test = gradcheck(torch.nn.functional.linear, input, eps=1e-6, atol=1e-4)
-# print(test)
\ No newline at end of file
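A minimal forward-only call of the kernel these tests exercise, with the variable-B/C layout used above; a sketch assuming the installed wheel provides the same selective_scan_fn imported at the top of this test.

import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, dstate, seqlen = 2, 4, 8, 128
device = "cuda"
u = torch.randn(batch, dim, seqlen, device=device)           # input sequence
delta = 0.5 * torch.rand(batch, dim, seqlen, device=device)  # per-step discretization
A = -0.5 * torch.rand(dim, dstate, device=device)
B = torch.randn(batch, dstate, seqlen, device=device)        # input-dependent B
C = torch.randn(batch, dstate, seqlen, device=device)        # input-dependent C
D = torch.randn(dim, device=device)
out, last_state = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True, return_last_state=True)
print(out.shape, last_state.shape)  # (batch, dim, seqlen) and (batch, dim, dstate)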
diff --git a/mamba/tests/ops/triton/test_selective_state_update.py b/mamba/tests/ops/triton/test_selective_state_update.py
deleted file mode 100644
index 70a8d79d9cad3e4d33897478caf178bd96d0ae5a..0000000000000000000000000000000000000000
--- a/mamba/tests/ops/triton/test_selective_state_update.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright (C) 2023, Tri Dao.
-
-import math
-
-import torch
-import torch.nn.functional as F
-import pytest
-
-from einops import rearrange
-
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
-
-
-@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize('itype', [torch.float16])
-@pytest.mark.parametrize("has_z", [False, True])
-# @pytest.mark.parametrize('has_z', [True])
-@pytest.mark.parametrize("dstate", [16, 32, 64])
-# @pytest.mark.parametrize("dstate", [16])
-@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
-# @pytest.mark.parametrize("dim", [2048])
-def test_selective_state_update(dim, dstate, has_z, itype):
-    device = "cuda"
-    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
-    if itype == torch.bfloat16:
-        rtol, atol = 1e-2, 5e-2
-    # set seed
-    torch.random.manual_seed(0)
-    batch_size = 2
-    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
-    x = torch.randn(batch_size, dim, device=device, dtype=itype)
-    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
-    dt_bias = torch.rand(dim, device=device) - 4.0
-    A = -torch.rand(dim, dstate, device=device) - 1.0
-    B = torch.randn(batch_size, dstate, device=device)
-    C = torch.randn(batch_size, dstate, device=device)
-    D = torch.randn(dim, device=device)
-    if has_z:
-        z = torch.randn_like(x)
-    else:
-        z = None
-    state_ref = state.detach().clone()
-    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
-    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
-
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
-    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
-    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
diff --git a/mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl b/mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl
new file mode 100644
index 0000000000000000000000000000000000000000..644884621c7f4f5c570309baa5e76da306e5afb2
--- /dev/null
+++ b/mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96c19c64be06f0bb5607c6a2406fdf7289a0d83037ea005ec56a2e4ae5d40ec7
+size 145927203
diff --git a/requirements.txt b/requirements.txt
index 4ab553c59fb25f9037863b9ff8728087e2cc7685..8cb776f0a752ee9c9cb1a00d5bab959a8f778bfa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,5 +8,6 @@ setuptools
 timm
 transformers
 wheel
-# torch --index-url https://download.pytorch.org/whl/cu118
-# torchvision --index-url https://download.pytorch.org/whl/cu118
\ No newline at end of file
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.1.1
+torchvision==0.16.1