diff --git a/app.py b/app.py
index 974a442a7cab0fd17cad9a20e6e6788ba68b9e86..61f771877272c3a46f7fad71564a92a3787be668 100644
--- a/app.py
+++ b/app.py
@@ -1,15 +1,13 @@
-import os
-# import spaces
+import shlex
+import subprocess
+import spaces
 import torch
-os.system("nvidia-smi")
-print("TORCH_CUDA", torch.cuda.is_available())
-
-
-
 # install packages for mamba
 def install():
     print("Install personal packages", flush=True)
-    os.system("bash install.sh")
+    subprocess.run(shlex.split("pip install causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl"))
+    subprocess.run(shlex.split("pip install mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl"))
 
 install()
@@ -25,7 +23,7 @@ from videomamba_video import videomamba_tiny
 from kinetics_class_index import kinetics_classnames
 from imagenet_class_index import imagenet_classnames
 from transforms import (
-    GroupNormalize, GroupScale, GroupCenterCrop, 
+    GroupNormalize, GroupScale, GroupCenterCrop,
     Stack, ToTorchFormatTensor
 )
 
@@ -38,7 +36,7 @@ from huggingface_hub import hf_hub_download
 device = "cuda"
 model_video_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_k400_f16_res224.pth")
 model_image_path = hf_hub_download(repo_id="OpenGVLab/VideoMamba", filename="videomamba_t16_in1k_res224.pth")
-# Pick a pretrained model 
+# Pick a pretrained model
 model_video = videomamba_tiny(num_classes=400, num_frames=16)
 video_sd = torch.load(model_video_path, map_location='cpu')
 model_video.load_state_dict(video_sd)
@@ -55,7 +53,7 @@ for k, v in kinetics_classnames.items():
     kinetics_id_to_classname[k] = v
 imagenet_id_to_classname = {}
 for k, v in imagenet_classnames.items():
-    imagenet_id_to_classname[k] = v[1] 
+    imagenet_id_to_classname[k] = v[1]
 
 
 def get_index(num_frames, num_segments=8):
@@ -83,7 +81,7 @@ def load_video(video_path):
         GroupCenterCrop(crop_size),
         Stack(),
         ToTorchFormatTensor(),
-        GroupNormalize(input_mean, input_std) 
+        GroupNormalize(input_mean, input_std)
     ])
 
     images_group = list()
@@ -92,28 +90,24 @@ def load_video(video_path):
     for img in images:
         images_group.append(img)
     torch_imgs = transform(images_group)
     return torch_imgs
-    
-# @spaces.GPU
+
+@spaces.GPU
 def inference_video(video):
     vid = load_video(video)
-    
+
     # The model expects inputs of shape: B x C x H x W
     TC, H, W = vid.shape
     inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
-    
+
     with torch.no_grad():
         prediction = model_video(inputs.to(device))
         prediction = F.softmax(prediction, dim=1).flatten()
     return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
-
-
-def set_example_video(example: list) -> dict:
-    return gr.Video.update(value=example[0])
 
 
-# @spaces.GPU
+@spaces.GPU
 def inference_image(img):
     image = img
     image_transform = T.Compose(
@@ -125,10 +119,10 @@ def inference_image(img):
         ]
     )
     image = image_transform(image)
-    
+
     # The model expects inputs of shape: B x C x H x W
     image = image.unsqueeze(0)
-    
+
     with torch.no_grad():
         prediction = model_image(image.to(device))
         prediction = F.softmax(prediction, dim=1).flatten()
@@ -136,10 +130,6 @@ def inference_image(img):
     return {imagenet_id_to_classname[str(i)]: float(prediction[i]) for i in range(1000)}
 
 
-def set_example_image(example: list) -> dict:
-    return gr.Image.update(value=example[0])
-
-
 demo = gr.Blocks()
 with demo:
     gr.Markdown(
@@ -154,26 +144,26 @@ with demo:
         with gr.Row():
             with gr.Column():
                 with gr.Row():
-                    input_video = gr.Video(label='Input Video').style(height=360)
+                    input_video = gr.Video(label='Input Video', height=360)
                 with gr.Row():
                     submit_video_button = gr.Button('Submit')
             with gr.Column():
                 label_video = gr.Label(num_top_classes=5)
         with gr.Row():
-            example_videos = gr.Dataset(components=[input_video], samples=[['./videos/hitting_baseball.mp4'], ['./videos/hoverboarding.mp4'], ['./videos/yoga.mp4']])
-            
+            gr.Examples(examples=['./videos/hitting_baseball.mp4', './videos/hoverboarding.mp4', './videos/yoga.mp4'], inputs=input_video, outputs=label_video, fn=inference_video, cache_examples=True)
+
     with gr.Tab("Image"):
         # with gr.Box():
         with gr.Row():
             with gr.Column():
                 with gr.Row():
-                    input_image = gr.Image(label='Input Image', type='pil').style(height=360)
+                    input_image = gr.Image(label='Input Image', type='pil', height=360)
                 with gr.Row():
                     submit_image_button = gr.Button('Submit')
             with gr.Column():
                 label_image = gr.Label(num_top_classes=5)
         with gr.Row():
-            example_images = gr.Dataset(components=[input_image], samples=[['./images/cat.png'], ['./images/dog.png'], ['./images/panda.png']])
+            gr.Examples(examples=['./images/cat.png', './images/dog.png', './images/panda.png'], inputs=input_image, outputs=label_image, fn=inference_image, cache_examples=True)
 
     gr.Markdown(
         """
@@ -182,9 +172,6 @@ with demo:
     )
 
     submit_video_button.click(fn=inference_video, inputs=input_video, outputs=label_video)
-    example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos._components)
     submit_image_button.click(fn=inference_image, inputs=input_image, outputs=label_image)
-    example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images._components)
 
-demo.launch(enable_queue=True)
-# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
\ No newline at end of file
+demo.queue(max_size=20).launch()
diff --git a/causal-conv1d/AUTHORS b/causal-conv1d/AUTHORS
deleted file mode 100644
index 88193855314bb723ced1860384e417954f559700..0000000000000000000000000000000000000000
--- a/causal-conv1d/AUTHORS
+++ /dev/null
@@ -1 +0,0 @@
-Tri Dao, tri@tridao.me
diff --git a/causal-conv1d/LICENSE b/causal-conv1d/LICENSE
deleted file mode 100644
index 5860e4b33f3d9d85fc636137c559331d51783a5b..0000000000000000000000000000000000000000
--- a/causal-conv1d/LICENSE
+++ /dev/null
@@ -1,29 +0,0 @@
-BSD 3-Clause License
-
-Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-* Redistributions of source code must retain the above copyright notice, this
-  list of conditions and the following disclaimer.
-
-* Redistributions in binary form must reproduce the above copyright notice,
-  this list of conditions and the following disclaimer in the documentation
-  and/or other materials provided with the distribution.
-
-* Neither the name of the copyright holder nor the names of its
-  contributors may be used to endorse or promote products derived from
-  this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/causal-conv1d/README.md b/causal-conv1d/README.md deleted file mode 100644 index 4e905425a650d77c5c4854e4c4a261778c4d2690..0000000000000000000000000000000000000000 --- a/causal-conv1d/README.md +++ /dev/null @@ -1 +0,0 @@ -# Causal depthwise conv1d in CUDA with a PyTorch interface diff --git a/causal-conv1d/causal_conv1d/__init__.py b/causal-conv1d/causal_conv1d/__init__.py deleted file mode 100644 index cc4d610a1e557cabd723fb6e33438f03c5c4bf66..0000000000000000000000000000000000000000 --- a/causal-conv1d/causal_conv1d/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -__version__ = "1.0.0" - -from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update diff --git a/causal-conv1d/causal_conv1d/causal_conv1d_interface.py b/causal-conv1d/causal_conv1d/causal_conv1d_interface.py deleted file mode 100644 index f66143c39e767572ca12112811a384239b8beb63..0000000000000000000000000000000000000000 --- a/causal-conv1d/causal_conv1d/causal_conv1d_interface.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -import torch -import torch.nn.functional as F - - -import causal_conv1d_cuda - - -class CausalConv1dFn(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, bias=None, activation=None): - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - ctx.save_for_backward(x, weight, bias) - ctx.activation = activation in ["silu", "swish"] - out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) - return out - - @staticmethod - def backward(ctx, dout): - x, weight, bias = ctx.saved_tensors - if dout.stride(2) != 1 and dout.stride(1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - # Here we just pass in None and dx will be allocated in the C++ code. 
- dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( - x, weight, bias, dout, None, ctx.activation - ) - return dx, dweight, dbias if bias is not None else None, None - - -def causal_conv1d_fn(x, weight, bias=None, activation=None): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - activation: either None or "silu" or "swish" - - out: (batch, dim, seqlen) - """ - return CausalConv1dFn.apply(x, weight, bias, activation) - - -def causal_conv1d_ref(x, weight, bias=None, activation=None): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) - out = out[..., :seqlen] - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) - - -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): - """ - x: (batch, dim) - conv_state: (batch, dim, width) - weight: (dim, width) - bias: (dim,) - - out: (batch, dim) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) - - -def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): - """ - x: (batch, dim) - conv_state: (batch, dim, width) - weight: (dim, width) - bias: (dim,) - - out: (batch, dim) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - batch, dim = x.shape - width = weight.shape[1] - assert conv_state.shape == (batch, dim, width) - assert weight.shape == (dim, width) - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - out = torch.sum(conv_state * weight, dim=-1) # (B D) - if bias is not None: - out += bias - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/causal-conv1d/csrc/causal_conv1d.cpp b/causal-conv1d/csrc/causal_conv1d.cpp deleted file mode 100644 index 1c80516ac8599d4d80910a1d4d85c4c435cf1e4f..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/causal_conv1d.cpp +++ /dev/null @@ -1,333 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include <ATen/cuda/CUDAContext.h> -#include <c10/cuda/CUDAGuard.h> -#include <torch/extension.h> -#include <vector> - -#include "causal_conv1d.h" - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -template<typename input_t, typename weight_t> -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template <typename input_t, typename weight_t> -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -template<typename input_t, typename weight_t> -void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -template<typename input_t, typename weight_t> -void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); - -template<typename input_t, typename weight_t> -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -void set_conv_params_fwd(ConvParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, - // device pointers - const at::Tensor x, - const at::Tensor weight, - const at::Tensor out, - void* bias_ptr, - bool silu_activation) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.width = width; - - params.silu_activation = silu_activation; - - // Set the pointers and strides. - params.x_ptr = x.data_ptr(); - params.weight_ptr = weight.data_ptr(); - params.bias_ptr = bias_ptr; - params.out_ptr = out.data_ptr(); - // All stride are in elements, not bytes. - params.x_batch_stride = x.stride(0); - params.x_c_stride = x.stride(1); - params.x_l_stride = x.stride(-1); - params.weight_c_stride = weight.stride(0); - params.weight_width_stride = weight.stride(1); - params.out_batch_stride = out.stride(0); - params.out_c_stride = out.stride(1); - params.out_l_stride = out.stride(-1); -} - - -void set_conv_params_bwd(ConvParamsBwd ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, - // device pointers - const at::Tensor x, - const at::Tensor weight, - void* bias_ptr, - const at::Tensor dout, - const at::Tensor dx, - const at::Tensor dweight, - void* dbias_ptr, - bool silu_activation) { - // Pass in "dout" instead of "out", we're not gonna use "out" at all. - set_conv_params_fwd(params, batch, dim, seqlen, width, - x, weight, dout, bias_ptr, silu_activation); - - // Set the pointers and strides. - params.dout_ptr = dout.data_ptr(); - params.dx_ptr = dx.data_ptr(); - params.dweight_ptr = dweight.data_ptr(); - params.dbias_ptr = dbias_ptr; - // All stride are in elements, not bytes. 
- params.dout_batch_stride = dout.stride(0); - params.dout_c_stride = dout.stride(1); - params.dout_l_stride = dout.stride(2); - params.dweight_c_stride = dweight.stride(0); - params.dweight_width_stride = dweight.stride(1); - params.dx_batch_stride = dx.stride(0); - params.dx_c_stride = dx.stride(1); - params.dx_l_stride = dx.stride(2); -} - -at::Tensor -causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, - const c10::optional<at::Tensor> &bias_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); - const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; - - if (is_channel_last) { - TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); - } - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - at::Tensor out = torch::empty_like(x); - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, - silu_activation); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { - if (!is_channel_last) { - causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream); - } else { - causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream); - } - }); - }); - return out; -} - -std::vector<at::Tensor> -causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight, - const c10::optional<at::Tensor> &bias_, - at::Tensor &dout, - c10::optional<at::Tensor> &dx_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - TORCH_CHECK(dout.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.size(-1); - - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - CHECK_SHAPE(dout, batch_size, dim, seqlen); - - TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); - const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; - if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } - if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - at::Tensor dx; - if (dx_.has_value()) { - dx = dx_.value(); - TORCH_CHECK(dx.scalar_type() == input_type); - TORCH_CHECK(dx.is_cuda()); - CHECK_SHAPE(dx, batch_size, dim, seqlen); - if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } - if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } - } else { - dx = torch::empty_like(x); - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - - at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat)); - at::Tensor dbias; - if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); } - - ConvParamsBwd params; - set_conv_params_bwd(params, batch_size, dim, seqlen, width, - x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, - dout, dx, dweight, bias_.has_value() ? 
dbias.data_ptr() : nullptr, - silu_activation); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { - if (!is_channel_last) { - causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream); - } else { - causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream); - } - }); - }); - return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias}; -} - -at::Tensor -causal_conv1d_update(const at::Tensor &x, - const at::Tensor &conv_state, - const at::Tensor &weight, - const c10::optional<at::Tensor> &bias_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - TORCH_CHECK(conv_state.scalar_type() == input_type); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(conv_state.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - at::Tensor out = torch::empty_like(x); - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, - bias_.has_value() ? bias_.value().data_ptr() : nullptr, - silu_activation); - params.conv_state_ptr = conv_state.data_ptr(); - // All stride are in elements, not bytes. - params.conv_state_batch_stride = conv_state.stride(0); - params.conv_state_c_stride = conv_state.stride(1); - params.conv_state_l_stride = conv_state.stride(2); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { - causal_conv1d_update_cuda<input_t, weight_t>(params, stream); - }); - }); - return out; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); - m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); - m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); -} diff --git a/causal-conv1d/csrc/causal_conv1d.h b/causal-conv1d/csrc/causal_conv1d.h deleted file mode 100644 index 844ed92cfc91a881e58fccfca001a13ebcc434cc..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/causal_conv1d.h +++ /dev/null @@ -1,53 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ConvParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, width; - bool silu_activation; - - index_t x_batch_stride; - index_t x_c_stride; - index_t x_l_stride; - index_t weight_c_stride; - index_t weight_width_stride; - index_t out_batch_stride; - index_t out_c_stride; - index_t out_l_stride; - - index_t conv_state_batch_stride; - index_t conv_state_c_stride; - index_t conv_state_l_stride; - - // Common data pointers. - void *__restrict__ x_ptr; - void *__restrict__ weight_ptr; - void *__restrict__ bias_ptr; - void *__restrict__ out_ptr; - - void *__restrict__ conv_state_ptr; -}; - -struct ConvParamsBwd: public ConvParamsBase { - index_t dx_batch_stride; - index_t dx_c_stride; - index_t dx_l_stride; - index_t dweight_c_stride; - index_t dweight_width_stride; - index_t dout_batch_stride; - index_t dout_c_stride; - index_t dout_l_stride; - - // Common data pointers. - void *__restrict__ dx_ptr; - void *__restrict__ dweight_ptr; - void *__restrict__ dbias_ptr; - void *__restrict__ dout_ptr; -}; - diff --git a/causal-conv1d/csrc/causal_conv1d_bwd.cu b/causal-conv1d/csrc/causal_conv1d_bwd.cu deleted file mode 100644 index 66609750a30a86a284451871ca163d79a0529047..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/causal_conv1d_bwd.cu +++ /dev/null @@ -1,525 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include <c10/util/BFloat16.h> -#include <c10/util/Half.h> -#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include <cub/block/block_load.cuh> -#include <cub/block/block_store.cuh> -#include <cub/block/block_reduce.cuh> - -#include "causal_conv1d.h" -#include "causal_conv1d_common.h" -#include "static_switch.h" - -template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_> -struct Causal_conv1d_bwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr bool kSiluAct = kSiluAct_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits - // (since then we'd have 8 values of float, and each round we can exchange 4 floats). - static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType<kNBytes * kNElts>::Type; - using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>; - using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>; - using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>; - using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>; - static constexpr int kSmemIOSize = kIsVecLoad - ? 
0 - : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1); - static constexpr int kSmemSize = std::max({kSmemExchangeSize, - int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize); -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_bwd_kernel(ConvParamsBwd params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr bool kSiluAct = Ktraits::kSiluAct; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds; - constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_); - auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_); - auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_); - auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_); - vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize); - vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds; - auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize); - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride - + dim_id * params.x_c_stride; - weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride; - input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride - + dim_id * params.dout_c_stride; - input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride - + dim_id * params.dx_c_stride; - float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]); - - // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0. 
- if (tidx == 0) { - if constexpr (!kSiluAct) { - input_t zeros[kNElts] = {0}; - smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0]; - } else { - float zeros[kNElts] = {0}; - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r]; - } - } - } - - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; } - - float dweight_vals[kWidth] = {0}; - float dbias_val = 0; - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; - x += (n_chunks - 1) * kChunkSize; - dout += (n_chunks - 1) * kChunkSize; - dx += (n_chunks - 1) * kChunkSize; - for (int chunk = n_chunks - 1; chunk >= 0; --chunk) { - input_t x_vals_load[2 * kNElts] = {0}; - input_t dout_vals_load[2 * kNElts] = {0}; - if constexpr(kIsVecLoad) { - Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); - Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - __syncthreads(); - Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); - __syncthreads(); - Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize); - } - float dout_vals[2 * kNElts], x_vals[2 * kNElts]; - if constexpr (!kSiluAct) { - __syncthreads(); - // Thread 0 don't write yet, so that thread kNThreads - 1 can read - // the first elements of the next chunk. - if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; } - __syncthreads(); - reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0]; - __syncthreads(); - // Now thread 0 can write the first elements of the current chunk. - if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; } - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { - dout_vals[i] = float(dout_vals_load[i]); - x_vals[i] = float(x_vals_load[i]); - } - } else { - if (tidx == 0 && chunk > 0) { - if constexpr(kIsVecLoad) { - reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1]; - } else { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; } - } - } - } - __syncthreads(); - smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; - __syncthreads(); - if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; } - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - // Recompute the output - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - float out_val = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val)); - dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val - * (1.0f + out_val * (1.0f - out_sigmoid_val)); - } - // Exchange the dout_vals. 
It's possible that we need to do 2 rounds of exchange - // if input_t is 16 bits (since then we'd have 8 values of float) - __syncthreads(); - // Thread 0 don't write yet, so that thread kNThreads - 1 can read - // the first elements of the next chunk. - if (tidx > 0) { - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r]; - } - } - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r] - = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)]; - } - __syncthreads(); - // Now thread 0 can write the first elements of the current chunk. - if (tidx == 0) { - #pragma unroll - for (int r = 0; r < kNExchangeRounds; ++r) { - smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r]; - } - } - } - dout -= kChunkSize; - x -= kChunkSize; - - #pragma unroll - for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; } - - float dx_vals[kNElts] = {0}; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1]; - } - } - - input_t dx_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; } - if constexpr(kIsVecLoad) { - Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize); - } - dx -= kChunkSize; - - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1]; - } - } - } - - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - __syncthreads(); - dweight_vals[w] = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]); - if (tidx == 0) { - atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]); - } - } - if (params.bias_ptr != nullptr) { - __syncthreads(); - dbias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val); - if (tidx == 0) { - atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val); - } - } -} - -template<int kNThreads, int kWidth, typename input_t, typename weight_t> -void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 
4 : 8; - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { - using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_bwd_kernel<Ktraits>; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); -} - -template<typename input_t, typename weight_t> -void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_> -struct Causal_conv1d_channellast_bwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. - // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 - // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr bool kSiluAct = kSiluAct_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType<kNBytes * kNElts>::Type; - // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr bool kSiluAct = Ktraits::kSiluAct; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNWarp = Ktraits::kNWarps; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. 
- __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; - __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; - - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) - + chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - float *dweight = reinterpret_cast<float *>(params.dweight_ptr) - + chunk_c_id * kChunkSizeC * params.dweight_c_stride; - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t dout_vals_load[kNElts] = {0}; - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride); - reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride); - } - reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; - reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0]; - } - // Load the elements from the previous chunk or next chunk that are needed for convolution. 
- if (l_idx < kWidth - 1) { - input_t dout_vals_load[kNElts] = {0}; - input_t x_vals_load[kNElts] = {0}; - if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride); - } - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride); - } - reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; - reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0]; - } - // Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs - if constexpr (kSiluAct) { - if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {0}; - if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride); - } - reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0]; - } - } - - __syncthreads(); - - constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 
0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {0}; - if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; - } - } - float dout_vals[kLPerThread + kWidth - 1]; - float x_vals[kWidth - 1 + kLPerThread + kWidth - 1]; - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]); - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - - if constexpr (kSiluAct) { // Recompute the output - #pragma unroll - for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - #pragma unroll - for (int i = 0; i < kLPerThread + kWidth - 1; ++i) { - float out_val = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { out_val += weight_vals[w] * x_vals[i + w]; } - float out_val_sigmoid = 1.f / (1.f + expf(-out_val)); - dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid)); - } - } - - float dweight_vals[kWidth] = {0}; - SumOp<float> sum_op; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { dweight_vals[w] += x_vals[i + w] * dout_vals[i]; } - dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op); - if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]); - } - } - - if (params.bias_ptr != nullptr) { - float dbias_val = 0.f; - for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; } - dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op); - if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { - atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val); - } - } - - float dx_vals[kLPerThread] = {0}; - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; } - } - // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads. 
- __syncwarp(); - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = dx_vals[i]; } - __syncthreads(); - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t dx_vals_store[kNElts]; - reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0]; - } - } - -} - -template<int kNThreads, int kWidth, typename input_t, typename weight_t> -void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { - using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, 64, kSiluAct, true, input_t, weight_t>; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits>; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); - kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -template<typename input_t, typename weight_t> -void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream); - -template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t 
stream); -template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/causal-conv1d/csrc/causal_conv1d_common.h b/causal-conv1d/csrc/causal_conv1d_common.h deleted file mode 100644 index 8dd6a333b52163986c085f71475709706ce8f9c3..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/causal_conv1d_common.h +++ /dev/null @@ -1,64 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include <cuda_bf16.h> -#include <cuda_fp16.h> - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<int BYTES> struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<typename T> -struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } -}; - -template<int THREADS> -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template<typename T, typename Operator> - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce<OFFSET>::run(x, op); - } -}; - -template<> -struct Allreduce<2> { -template<typename T, typename Operator> -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; diff --git a/causal-conv1d/csrc/causal_conv1d_fwd.cu b/causal-conv1d/csrc/causal_conv1d_fwd.cu deleted file mode 100644 index 74a1459f88a87ef427075a25e5081899e382efc0..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/causal_conv1d_fwd.cu +++ /dev/null @@ -1,350 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -#include <c10/util/BFloat16.h> -#include <c10/util/Half.h> -#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include <cub/block/block_load.cuh> -#include <cub/block/block_store.cuh> - -#include "causal_conv1d.h" -#include "causal_conv1d_common.h" -#include "static_switch.h" - -template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_> -struct Causal_conv1d_fwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType<kNBytes * kNElts>::Type; - using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>; - using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>; - using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>; - static constexpr int kSmemIOSize = kIsVecLoad - ? 0 - : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_); - auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_); - auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_); - auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_); - vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize); - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y; - input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); - - // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. 
- if (tidx == 0) { - input_t zeros[kNElts] = {0}; - smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0]; - } - - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; - for (int chunk = 0; chunk < n_chunks; ++chunk) { - input_t x_vals_load[2 * kNElts] = {0}; - if constexpr(kIsVecLoad) { - Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - __syncthreads(); - Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); - } - x += kChunkSize; - __syncthreads(); - // Thread kNThreads - 1 don't write yet, so that thread 0 can read - // the last elements of the previous chunk. - if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; } - __syncthreads(); - reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; - __syncthreads(); - // Now thread kNThreads - 1 can write the last elements of the current chunk. - if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; } - - float x_vals[2 * kNElts]; - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - - float out_vals[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - } - - if (params.silu_activation) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); - } - } - - input_t out_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } - if constexpr(kIsVecLoad) { - Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); - } - out += kChunkSize; - } -} - -template<int kNThreads, int kWidth, typename input_t, typename weight_t> -void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 
4 : 8; - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_fwd_kernel<Ktraits>; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -template<typename input_t, typename weight_t> -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_> -struct Causal_conv1d_channellast_fwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. - // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 - // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType<kNBytes * kNElts>::Type; - // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNWarp = Ktraits::kNWarps; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; - - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) - + chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride); - } - reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0]; - } - // Load the elements from the previous chunk that are needed for convolution. 
- if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride); - } - reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0]; - } - - __syncthreads(); - - constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {0}; - if (chunk_c_id + kChunkSizeC + row_idx < params.dim) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; - } - } - float x_vals[kWidth - 1 + kLPerThread]; - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - - float out_vals[kLPerThread]; - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - out_vals[i] = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[i + w]; } - if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } - } - - // Since kNThreadsPerRow is a power of 2 and <= 32, we only need syncwarp and not syncthreads. 
- __syncwarp(); - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } - __syncthreads(); - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t out_vals_store[kNElts]; - reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0]; - } - } - -} - -template<int kNThreads, int kWidth, typename input_t, typename weight_t> -void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - // printf("n_chunks_L: %d, n_chunks_C: %d\n", n_chunks_L, n_chunks_C); - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits>; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); - kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template<typename input_t, typename weight_t> -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); - -template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<float, 
at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/causal-conv1d/csrc/causal_conv1d_update.cu b/causal-conv1d/csrc/causal_conv1d_update.cu deleted file mode 100644 index 713e0ac883853491f9bdb0015b578657c228c1e7..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/causal_conv1d_update.cu +++ /dev/null @@ -1,96 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include <c10/util/BFloat16.h> -#include <c10/util/Half.h> -#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include <cub/block/block_load.cuh> -#include <cub/block/block_store.cuh> - -#include "causal_conv1d.h" -#include "causal_conv1d_common.h" -#include "static_switch.h" - -template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_> -struct Causal_conv1d_update_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_update_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y * kNThreads + tidx; - input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride - + channel_id * params.conv_state_c_stride; - weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 
0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); - - float weight_vals[kWidth] = {0}; - if (channel_id < params.dim) { - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - } - - float x_vals[kWidth] = {0}; - if (channel_id < params.dim) { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } - x_vals[kWidth - 1] = float(x[0]); - #pragma unroll - for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } - } - - float out_val = bias_val; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } - if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - if (channel_id < params.dim) { out[0] = input_t(out_val); } -} - -template<int kNThreads, int kWidth, typename input_t, typename weight_t> -void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>; - dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = &causal_conv1d_update_kernel<Ktraits>; - kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template<typename input_t, typename weight_t> -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/causal-conv1d/csrc/static_switch.h b/causal-conv1d/csrc/static_switch.h deleted file mode 100644 index 0f4ad3eb62235443d15c454b6691c2ec63645219..0000000000000000000000000000000000000000 --- a/causal-conv1d/csrc/static_switch.h +++ /dev/null @@ -1,25 +0,0 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... 
- code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function<BoolConst>(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - static constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - static constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/causal-conv1d/setup.py b/causal-conv1d/setup.py deleted file mode 100644 index 12e36bf988215a4c536278026e6f4401e66534da..0000000000000000000000000000000000000000 --- a/causal-conv1d/setup.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright (c) 2023, Tri Dao. -import sys -import warnings -import os -import re -import ast -from pathlib import Path -from packaging.version import parse, Version -import platform - -from setuptools import setup, find_packages -import subprocess - -import urllib.request -import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import torch -from torch.utils.cpp_extension import ( - BuildExtension, - CppExtension, - CUDAExtension, - CUDA_HOME, -) - - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - -PACKAGE_NAME = "causal_conv1d" - -BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}" - -# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels -# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation -FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE" -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE" - - -def get_platform(): - """ - Returns the platform name as used in wheel filenames. - """ - if sys.platform.startswith("linux"): - return "linux_x86_64" - elif sys.platform == "darwin": - mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) - return f"macosx_{mac_version}_x86_64" - elif sys.platform == "win32": - return "win_amd64" - else: - raise ValueError("Unsupported platform: {}".format(sys.platform)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True - ) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary - # in that case. - warnings.warn( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." 
- ) - - -def append_nvcc_threads(nvcc_extra_args): - return nvcc_extra_args + ["--threads", "4"] - - -cmdclass = {} -ext_modules = [] - -if not SKIP_CUDA_BUILD: - print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - - check_if_cuda_home_none("causal_conv1d") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - if CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.6"): - raise RuntimeError( - "causal_conv1d is only supported on CUDA 11.6 and above. " - "Note: make sure nvcc has a supported version by running nvcc -V." - ) - - cc_flag.append("-gencode") - cc_flag.append("arch=compute_70,code=sm_70") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - - # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as - # torch._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 - if FORCE_CXX11_ABI: - torch._C._GLIBCXX_USE_CXX11_ABI = True - - ext_modules.append( - CUDAExtension( - name="causal_conv1d_cuda", - sources=[ - "csrc/causal_conv1d.cpp", - "csrc/causal_conv1d_fwd.cu", - "csrc/causal_conv1d_bwd.cu", - "csrc/causal_conv1d_update.cu", - ], - extra_compile_args={ - "cxx": ["-O3"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo", - ] - + cc_flag - ), - }, - include_dirs=[this_dir], - ) - ) - - -def get_package_version(): - with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) - public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") - if local_version: - return f"{public_version}+{local_version}" - else: - return str(public_version) - - -def get_wheel_url(): - # Determine the version numbers that will be used to determine the correct wheel - # We're using the CUDA version used to build torch, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - torch_cuda_version = parse(torch.version.cuda) - torch_version_raw = parse(torch.__version__) - # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 - # to save CI time. Minor versions should be compatible. 
- torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" - platform_name = get_platform() - causal_conv1d_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" - torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" - cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() - - # Determine wheel URL based on CUDA version, torch version, python version and OS - wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename - ) - return wheel_url, wheel_filename - - -class CachedWheelsCommand(_bdist_wheel): - """ - The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot - find an existing wheel (which is currently the case for all installs). We use - the environment parameters to detect whether there is already a pre-built version of a compatible - wheel available and short-circuits the standard full build pipeline. - """ - - def run(self): - if FORCE_BUILD: - return super().run() - - wheel_url, wheel_filename = get_wheel_url() - print("Guessing wheel URL: ", wheel_url) - try: - urllib.request.urlretrieve(wheel_url, wheel_filename) - - # Make the archive - # Lifted from the root wheel processing command - # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - impl_tag, abi_tag, plat_tag = self.get_tag() - archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - - wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") - print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) - except urllib.error.HTTPError: - print("Precompiled wheel not found. Building from source...") - # If the wheel could not be downloaded, build from source - super().run() - - -setup( - name=PACKAGE_NAME, - version=get_package_version(), - packages=find_packages( - exclude=( - "build", - "csrc", - "include", - "tests", - "dist", - "docs", - "benchmarks", - "causal_conv1d.egg-info", - ) - ), - author="Tri Dao", - author_email="tri@tridao.me", - description="Causal depthwise conv1d in CUDA, with a PyTorch interface", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/Dao-AILab/causal-conv1d", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - ext_modules=ext_modules, - cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} - if ext_modules - else { - "bdist_wheel": CachedWheelsCommand, - }, - python_requires=">=3.7", - install_requires=[ - "torch", - "packaging", - "ninja", - ], -) diff --git a/causal-conv1d/tests/test_causal_conv1d.py b/causal-conv1d/tests/test_causal_conv1d.py deleted file mode 100644 index 6e5985cfb0582e6656afb1d8b5c1de78f24f4276..0000000000000000000000000000000000000000 --- a/causal-conv1d/tests/test_causal_conv1d.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (C) 2023, Tri Dao. 
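# Note on the single-step path exercised by test_causal_conv1d_update below:
# causal_conv1d_update keeps a rolling conv_state of shape (batch, dim, width); each call
# shifts the state left by one sample (in place), appends the new x, and returns the
# per-channel dot product with the filter taps plus bias, optionally passed through SiLU
# (mirroring causal_conv1d_update.cu deleted above). A rough sketch of one step, with
# names illustrative only:
#
#   conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
#   conv_state[:, :, -1] = x
#   out = torch.einsum("bdw,dw->bd", conv_state.float(), weight.float()) + (bias if bias is not None else 0.0)
#   out = F.silu(out) if activation == "silu" else out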
- -import math - -import torch -import pytest - -from einops import rearrange - -from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref -from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref - - -@pytest.mark.parametrize("channel_last", [False, True]) -# @pytest.mark.parametrize('channel_last', [True]) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize('itype', [torch.float16]) -@pytest.mark.parametrize("silu_activation", [False, True]) -# @pytest.mark.parametrize('silu_activation', [True]) -@pytest.mark.parametrize("has_bias", [False, True]) -# @pytest.mark.parametrize('has_bias', [True]) -@pytest.mark.parametrize("width", [2, 3, 4]) -# @pytest.mark.parametrize('width', [2]) -@pytest.mark.parametrize( - "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] -) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last): - device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - # batch_size = 1 - dim = 4096 + 32 # Try dim not divisible by 64 - # dim = 64 - if not channel_last: - x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() - else: - x = rearrange( - torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" - ).requires_grad_() - weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - else: - bias = None - x_ref = x.detach().clone().requires_grad_() - weight_ref = weight.detach().clone().requires_grad_() - bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None - activation = None if not silu_activation else "silu" - out = causal_conv1d_fn(x, weight, bias, activation=activation) - out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - g = torch.randn_like(out) - out_ref.backward(g) - out.backward(g) - - print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") - print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") - if has_bias: - print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") - - assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) - assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) - if has_bias: - assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) - - -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize('itype', [torch.float16]) -@pytest.mark.parametrize("silu_activation", [False, True]) -# @pytest.mark.parametrize('silu_activation', [False]) -@pytest.mark.parametrize("has_bias", [False, True]) -# @pytest.mark.parametrize('has_bias', [True]) -@pytest.mark.parametrize("width", [2, 3, 4]) -# 
@pytest.mark.parametrize('width', [2]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -# @pytest.mark.parametrize("dim", [2048]) -def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): - device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - # batch_size = 1 - # dim = 64 - x = torch.randn(batch_size, dim, device=device, dtype=itype) - conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype) - weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - else: - bias = None - conv_state_ref = conv_state.detach().clone() - activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) - out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert torch.equal(conv_state, conv_state_ref) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - -# @pytest.mark.parametrize("channel_last", [False, True]) -@pytest.mark.parametrize('channel_last', [True]) -# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('itype', [torch.bfloat16]) -# @pytest.mark.parametrize("silu_activation", [False, True]) -@pytest.mark.parametrize('silu_activation', [True]) -# @pytest.mark.parametrize("has_bias", [False, True]) -@pytest.mark.parametrize('has_bias', [True]) -# @pytest.mark.parametrize("width", [2, 3, 4]) -@pytest.mark.parametrize('width', [4]) -@pytest.mark.parametrize( - # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] - "seqlen", [2048] -) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): - device = "cuda" - # set seed - torch.random.manual_seed(0) - batch_size = 2 - # batch_size = 1 - dim = 4096 + 32 # Try dim not divisible by 64 - # dim = 64 - if not channel_last: - x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() - else: - x = rearrange( - torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" - ).requires_grad_() - weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - else: - bias = None - activation = None if not silu_activation else "silu" - out0 = causal_conv1d_fn(x, weight, bias, activation=activation) - g = torch.randn_like(out0) - dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g) - dw_atol = 1e-4 - db_atol = 1e-4 - - for i in range(10000): - out = causal_conv1d_fn(x, weight, bias, activation=activation) - dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g) - dw_equal = torch.allclose(dw, dw0, atol=dw_atol) - # if not dw_equal: - # breakpoint() - if has_bias: - db_equal = torch.allclose(db, db0, atol=db_atol) - # if not db_equal: - # 
breakpoint() - assert torch.equal(out, out0) - assert torch.equal(dx, dx0) - assert dw_equal - if has_bias: - assert dw_equal diff --git a/causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl b/causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..3c346180501b1f8dc77018794c5676256e5043a3 --- /dev/null +++ b/causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78328bff9f0cf4814aa3c4029d63aa75128694e07ddae688b16215e3d8a2e7e7 +size 8424758 diff --git a/install.sh b/install.sh deleted file mode 100644 index a0473cd3342036d5a523de17668cdc1d39bd4521..0000000000000000000000000000000000000000 --- a/install.sh +++ /dev/null @@ -1,3 +0,0 @@ -pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118 -pip install -e causal-conv1d -pip install -e mamba \ No newline at end of file diff --git a/mamba/.gitmodules b/mamba/.gitmodules deleted file mode 100644 index a7445800fb64f3ae664c0b994a54235105986d2e..0000000000000000000000000000000000000000 --- a/mamba/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "3rdparty/lm-evaluation-harness"] - path = 3rdparty/lm-evaluation-harness - url = https://github.com/EleutherAI/lm-evaluation-harness/ diff --git a/mamba/AUTHORS b/mamba/AUTHORS deleted file mode 100644 index 38557a872f8d603ed963a05c211de7032de5926b..0000000000000000000000000000000000000000 --- a/mamba/AUTHORS +++ /dev/null @@ -1,2 +0,0 @@ -Tri Dao, tri@tridao.me -Albert Gu, agu@andrew.cmu.edu diff --git a/mamba/LICENSE b/mamba/LICENSE deleted file mode 100644 index f4abe24eb520fbb077753ae4f34bfaa43cb3b83f..0000000000000000000000000000000000000000 --- a/mamba/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2023 Tri Dao, Albert Gu - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/mamba/README.md b/mamba/README.md deleted file mode 100644 index 754cefd7f862a90bad8fbdff71e3793a4e7849e3..0000000000000000000000000000000000000000 --- a/mamba/README.md +++ /dev/null @@ -1,149 +0,0 @@ -# Mamba - - -> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ -> Albert Gu*, Tri Dao*\ -> Paper: https://arxiv.org/abs/2312.00752 - -## About - -Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. -It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), -with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). - -## Installation - -- `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. -- `pip install mamba-ssm`: the core Mamba package. 
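The causal Conv1d that the first bullet refers to (and that the CUDA kernels deleted earlier in this diff implement) is a depthwise 1D convolution with left-only padding and an optionally fused SiLU. The following PyTorch sketch is added for orientation only; shapes follow the tests above (`x` is `(batch, dim, seqlen)`, `weight` is `(dim, width)`), and the helper name is illustrative rather than part of either package:

```
import torch
import torch.nn.functional as F

def causal_depthwise_conv1d(x, weight, bias=None, activation=None):
    # x: (batch, dim, seqlen); weight: (dim, width) -- one row of filter taps per channel.
    dim, width = weight.shape
    # Pad on the left only, so output position t depends on x[..., :t + 1] (causality);
    # conv1d pads symmetrically, so the right-hand overhang is trimmed afterwards.
    out = F.conv1d(x.to(weight.dtype), weight.unsqueeze(1), bias,
                   padding=width - 1, groups=dim)[..., : x.shape[-1]]
    # The CUDA kernels above optionally fuse SiLU (x * sigmoid(x)) into the same pass.
    out = F.silu(out) if activation == "silu" else out
    return out.to(x.dtype)
```

The deleted tests (`test_causal_conv1d.py`) compare the fused `causal_conv1d_fn` against a reference with these semantics (`causal_conv1d_ref`), with tolerances loosened for fp16/bf16 inputs.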
- -It can also be built from source with `pip install .` from this repository. - -If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. - -Other requirements: -- Linux -- NVIDIA GPU -- PyTorch 1.12+ -- CUDA 11.6+ - -## Usage - -We expose several levels of interface with the Mamba model. - -### Selective SSM - -Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). - -Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). - -### Mamba Block - -The main module of this repository is the Mamba architecture block wrapping the selective SSM. - -Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). - -Usage: -``` -from mamba_ssm import Mamba - -batch, length, dim = 2, 64, 16 -x = torch.randn(batch, length, dim).to("cuda") -model = Mamba( - # This module uses roughly 3 * expand * d_model^2 parameters - d_model=dim, # Model dimension d_model - d_state=16, # SSM state expansion factor - d_conv=4, # Local convolution width - expand=2, # Block expansion factor -).to("cuda") -y = model(x) -assert y.shape == x.shape -``` - -### Mamba Language Model - -Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. - -Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). - -This is an example of how to integrate Mamba into an end-to-end neural network. -This example is used in the generation scripts below. - - - -## Pretrained Models - -Pretrained models are uploaded to -[HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, -`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`. - -The models will be autodownloaded by the generation script below. - -These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: - -| Parameters | Layers | Model dim. | -|------------|--------|------------| -| 130M | 12 | 768 | -| 370M | 24 | 1024 | -| 790M | 24 | 1536 | -| 1.4B | 24 | 2048 | -| 2.8B | 32 | 2560 | - -(The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) - -Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). -Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. - - -## Evaluations - -To run zero-shot evaluations of models (corresponding to Table 3 of the paper), -we use the -[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) -library. - -1. Pull the `lm-evaluation-harness` repo by `git submodule update --init - --recursive`. We use the `big-refactor` branch. -2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness` -3. 
Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): -``` -python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 -python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 -``` - -Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. - -## Inference - -The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) -1. autoloads a model from the HuggingFace Hub, -2. generates completions of a user-specified prompt, -3. benchmarks the inference speed of this generation. - -Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. - -### Examples - -To test generation latency (e.g. batch size = 1) with different sampling strategies: - -``` -python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 -python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5 -``` - -To test generation throughput with random prompts (e.g. large batch size): -``` -python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 -python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 -``` - -## Citation - -If you use this codebase, or otherwise found our work valuable, please cite Mamba: -``` -@article{mamba, - title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, - author={Gu, Albert and Dao, Tri}, - journal={arXiv preprint arXiv:2312.00752}, - year={2023} -} -``` diff --git a/mamba/assets/selection.png b/mamba/assets/selection.png deleted file mode 100644 index 69b109a8eed4e3c7516b23e2b39d37e842a4464b..0000000000000000000000000000000000000000 Binary files a/mamba/assets/selection.png and /dev/null differ diff --git a/mamba/benchmarks/benchmark_generation_mamba_simple.py b/mamba/benchmarks/benchmark_generation_mamba_simple.py deleted file mode 100644 index 8f2943cb4bde6f25eddb82b7b999c5c5f8b39acc..0000000000000000000000000000000000000000 --- a/mamba/benchmarks/benchmark_generation_mamba_simple.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
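# Overview (see the README's "Inference" section above): the script loads a model from
# the HuggingFace Hub, generates completions for a user prompt or for random token ids,
# and times the decode loop over a few repeats. For the Mamba path, the calls below
# reduce to roughly the following sketch (argument values are the script defaults):
#
#   model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m",
#                                            device="cuda", dtype=torch.float16)
#   out = model.generate(input_ids=input_ids,
#                        max_length=input_ids.shape[1] + 100,  # promptlen + genlen
#                        cg=True, return_dict_in_generate=True,
#                        temperature=1.0, top_k=1, top_p=1.0)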
- -import argparse -import time -import json - -import torch -import torch.nn.functional as F - -from einops import rearrange - -from transformers import AutoTokenizer, AutoModelForCausalLM - -from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel - - -parser = argparse.ArgumentParser(description="Generation benchmarking") -parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") -parser.add_argument("--prompt", type=str, default=None) -parser.add_argument("--promptlen", type=int, default=100) -parser.add_argument("--genlen", type=int, default=100) -parser.add_argument("--temperature", type=float, default=1.0) -parser.add_argument("--topk", type=int, default=1) -parser.add_argument("--topp", type=float, default=1.0) -parser.add_argument("--batch", type=int, default=1) -args = parser.parse_args() - -repeats = 3 -device = "cuda" -dtype = torch.float16 - -print(f"Loading model {args.model_name}") -is_mamba = args.model_name.startswith("state-spaces/mamba-") or "mamba" in args.model_name - -if is_mamba: - tokenizer = AutoTokenizer.from_pretrained("/home/zhulianghui/VisionProjects/mamba/ckpts/gpt-neox-20b-tokenizer") - model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) -else: - tokenizer = AutoTokenizer.from_pretrained(args.model_name) - model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) -model.eval() -print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") - -torch.random.manual_seed(0) -if args.prompt is None: - input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") - attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") -else: - tokens = tokenizer(args.prompt, return_tensors="pt") - input_ids = tokens.input_ids.to(device=device) - attn_mask = tokens.attention_mask.to(device=device) -max_length = input_ids.shape[1] + args.genlen - -if is_mamba: - fn = lambda: model.generate( - input_ids=input_ids, - max_length=max_length, - cg=True, - return_dict_in_generate=True, - output_scores=True, - enable_timing=False, - temperature=args.temperature, - top_k=args.topk, - top_p=args.topp, - ) -else: - fn = lambda: model.generate( - input_ids=input_ids, - attention_mask=attn_mask, - max_length=max_length, - return_dict_in_generate=True, - pad_token_id=tokenizer.eos_token_id, - do_sample=True, - temperature=args.temperature, - top_k=args.topk, - top_p=args.topp, - ) -out = fn() -if args.prompt is not None: - print(tokenizer.batch_decode(out.sequences.tolist())) - -torch.cuda.synchronize() -start = time.time() -for _ in range(repeats): - fn() -torch.cuda.synchronize() -print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") -print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") diff --git a/mamba/csrc/selective_scan/reverse_scan.cuh b/mamba/csrc/selective_scan/reverse_scan.cuh deleted file mode 100644 index d7e93174bb391d45271e6c77669a5e52d6c9cc78..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/reverse_scan.cuh +++ /dev/null @@ -1,401 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include <cub/config.cuh> - -#include <cub/util_ptx.cuh> -#include <cub/util_type.cuh> -#include <cub/block/block_raking_layout.cuh> -// #include <cub/detail/uninitialized_copy.cuh> -#include "uninitialized_copy.cuh" - -/** - * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ReductionOp> -__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { - static_assert(LENGTH > 0); - T retval = input[LENGTH - 1]; - #pragma unroll - for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } - return retval; -} - -/** - * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanInclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - T inclusive = postfix; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(inclusive, input[i]); - output[i] = inclusive; - } -} - -/** - * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanExclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - // Careful, output maybe be aliased to input - T exclusive = postfix; - T inclusive; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(exclusive, input[i]); - output[i] = exclusive; - exclusive = inclusive; - } - return inclusive; -} - - -/** - * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. - * - * LOGICAL_WARP_THREADS must be a power-of-two - */ -template < - typename T, ///< Data type being scanned - int LOGICAL_WARP_THREADS ///< Number of threads per logical warp - > -struct WarpReverseScan { - //--------------------------------------------------------------------- - // Constants and type definitions - //--------------------------------------------------------------------- - - /// Whether the logical warp size and the PTX warp size coincide - static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); - /// The number of warp scan steps - static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE; - static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); - - - //--------------------------------------------------------------------- - // Thread fields - //--------------------------------------------------------------------- - - /// Lane index in logical warp - unsigned int lane_id; - - /// Logical warp index in 32-thread physical warp - unsigned int warp_id; - - /// 32-thread physical warp member mask of logical warp - unsigned int member_mask; - - //--------------------------------------------------------------------- - // Construction - //--------------------------------------------------------------------- - - /// Constructor - explicit __device__ __forceinline__ - WarpReverseScan() - : lane_id(cub::LaneId()) - , warp_id(IS_ARCH_WARP ? 
0 : (lane_id / LOGICAL_WARP_THREADS)) - , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id)) - { - if (!IS_ARCH_WARP) { - lane_id = lane_id % LOGICAL_WARP_THREADS; - } - } - - - /// Broadcast - __device__ __forceinline__ T Broadcast( - T input, ///< [in] The value to broadcast - int src_lane) ///< [in] Which warp lane is to do the broadcasting - { - return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask); - } - - - /// Inclusive scan - template <typename ScanOpT> - __device__ __forceinline__ void InclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op) ///< [in] Binary scan operator - { - inclusive_output = input; - #pragma unroll - for (int STEP = 0; STEP < STEPS; STEP++) { - int offset = 1 << STEP; - T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>( - inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask - ); - // Perform scan op if from a valid peer - inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset - ? inclusive_output : scan_op(temp, inclusive_output); - } - } - - /// Exclusive scan - // Get exclusive from inclusive - template <typename ScanOpT> - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op, ///< [in] Binary scan operator - T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. - { - T inclusive_output; - InclusiveReverseScan(input, inclusive_output, scan_op); - warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask); - // initial value unknown - exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - - /** - * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined. - */ - template <typename ScanOpT> - __device__ __forceinline__ void ReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. - T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. - ScanOpT scan_op) ///< [in] Binary scan operator - { - InclusiveReverseScan(input, inclusive_output, scan_op); - // initial value unknown - exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - -}; - -/** - * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. 
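 *
 * For example, an inclusive postfix scan of [x0, x1, x2, x3] under + produces
 * [x0+x1+x2+x3, x1+x2+x3, x2+x3, x3]; the exclusive variant instead gives each
 * position the combination of the elements strictly after it (seeded by the
 * supplied postfix value).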
- */ -template < - typename T, ///< Data type being scanned - int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension - bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure - > -struct BlockReverseScan { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- - - /// Constants - /// The thread block size in threads - static constexpr int BLOCK_THREADS = BLOCK_DIM_X; - - /// Layout type for padded thread block raking grid - using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>; - // The number of reduction elements is not a multiple of the number of raking threads for now - static_assert(BlockRakingLayout::UNGUARDED); - - /// Number of raking threads - static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; - /// Number of raking elements per warp synchronous raking thread - static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; - /// Cooperative work can be entirely warp synchronous - static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); - - /// WarpReverseScan utility type - using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>; - - /// Shared memory storage layout type - struct _TempStorage { - typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid - }; - - - /// Alias wrapper allowing storage to be unioned - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - - - //--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - // Thread fields - _TempStorage &temp_storage; - unsigned int linear_tid; - T cached_segment[SEGMENT_LENGTH]; - - - //--------------------------------------------------------------------- - // Utility methods - //--------------------------------------------------------------------- - - /// Performs upsweep raking reduction, returning the aggregate - template <typename ScanOp> - __device__ __forceinline__ T Upsweep(ScanOp scan_op) { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data into registers - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; - #pragma unroll - for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { - raking_partial = scan_op(raking_partial, cached_segment[i]); - } - return raking_partial; - } - - - /// Performs exclusive downsweep raking scan - template <typename ScanOp> - __device__ __forceinline__ void ExclusiveDownsweep( - ScanOp scan_op, - T raking_partial) - { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data back into registers - if (!MEMOIZE) { - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - } - ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); - // Write data back to smem - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } - } - - - //--------------------------------------------------------------------- - // Constructors - //--------------------------------------------------------------------- - - /// Constructor - 
__device__ __forceinline__ BlockReverseScan( - TempStorage &temp_storage) - : - temp_storage(temp_storage.Alias()), - linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) - {} - - - /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - template < - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item - T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan operator - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. - { - if (WARP_SYNCHRONOUS) { - // Short-circuit directly to warp-synchronous scan - T block_aggregate; - WarpReverseScan warp_scan; - warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); - // Obtain warp-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output); - } else { - // Place thread partial into shared memory raking grid - T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); - detail::uninitialized_copy(placement_ptr, input); - cub::CTA_SYNC(); - // Reduce parallelism down to just raking threads - if (linear_tid < RAKING_THREADS) { - WarpReverseScan warp_scan; - // Raking upsweep reduction across shared partials - T upsweep_partial = Upsweep(scan_op); - // Warp-synchronous scan - T exclusive_partial, block_aggregate; - warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); - // Obtain block-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - // Update postfix with warpscan exclusive partial - T downsweep_postfix = linear_tid == RAKING_THREADS - 1 - ? block_postfix : scan_op(block_postfix, exclusive_partial); - // Exclusive raking downsweep scan - ExclusiveDownsweep(scan_op, downsweep_postfix); - } - cub::CTA_SYNC(); - // Grab thread postfix from shared memory - exclusive_output = *placement_ptr; - - // // Compute warp scan in each warp. - // // The exclusive output from the last lane in each warp is invalid. - // T inclusive_output; - // WarpReverseScan warp_scan; - // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); - - // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. 
- // T block_aggregate; - // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); - - // // Apply warp postfix to our lane's partial - // if (warp_id != 0) { - // exclusive_output = scan_op(warp_postfix, exclusive_output); - // if (lane_id == 0) { exclusive_output = warp_postfix; } - // } - - // // Use the first warp to determine the thread block postfix, returning the result in lane0 - // if (warp_id == 0) { - // T block_postfix = block_postfix_callback_op(block_aggregate); - // if (lane_id == 0) { - // // Share the postfix with all threads - // detail::uninitialized_copy(&temp_storage.block_postfix, - // block_postfix); - - // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 - // } - // } - - // cub::CTA_SYNC(); - - // // Incorporate thread block postfix into outputs - // T block_postfix = temp_storage.block_postfix; - // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } - } - } - - - /** - * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - */ - template < - int ITEMS_PER_THREAD, - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void InclusiveReverseScan( - T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items - T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan functor - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. - { - // Reduce consecutive thread items in registers - T thread_postfix = ThreadReverseReduce(input, scan_op); - // Exclusive thread block-scan - ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); - // Inclusive scan in registers with postfix as seed - ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); - } - -}; \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan.cpp b/mamba/csrc/selective_scan/selective_scan.cpp deleted file mode 100644 index f51af402a190dc14247ef8185a7d01b697313f02..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan.cpp +++ /dev/null @@ -1,497 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include <ATen/cuda/CUDAContext.h> -#include <c10/cuda/CUDAGuard.h> -#include <torch/extension.h> -#include <vector> - -#include "selective_scan.h" - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::ComplexFloat) { \ - using weight_t = c10::complex<float>; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -template<typename input_t, typename weight_t> -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); - -template <typename input_t, typename weight_t> -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); - -void set_ssm_params_fwd(SSMParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const at::Tensor u, - const at::Tensor delta, - const at::Tensor A, - const at::Tensor B, - const at::Tensor C, - const at::Tensor out, - const at::Tensor z, - const at::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - bool has_z, - bool delta_softplus) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.dstate = dstate; - params.n_groups = n_groups; - params.n_chunks = n_chunks; - params.dim_ngroups_ratio = dim / n_groups; - - params.delta_softplus = delta_softplus; - - params.is_variable_B = is_variable_B; - params.is_variable_C = is_variable_C; - - // Set the pointers and strides. - params.u_ptr = u.data_ptr(); - params.delta_ptr = delta.data_ptr(); - params.A_ptr = A.data_ptr(); - params.B_ptr = B.data_ptr(); - params.C_ptr = C.data_ptr(); - params.D_ptr = D_ptr; - params.delta_bias_ptr = delta_bias_ptr; - params.out_ptr = out.data_ptr(); - params.x_ptr = x_ptr; - params.z_ptr = has_z ? z.data_ptr() : nullptr; - params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; - // All stride are in elements, not bytes. - params.A_d_stride = A.stride(0); - params.A_dstate_stride = A.stride(1); - if (!is_variable_B) { - params.B_d_stride = B.stride(0); - } else { - params.B_batch_stride = B.stride(0); - params.B_group_stride = B.stride(1); - } - params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); - if (!is_variable_C) { - params.C_d_stride = C.stride(0); - } else { - params.C_batch_stride = C.stride(0); - params.C_group_stride = C.stride(1); - } - params.C_dstate_stride = !is_variable_C ? 
C.stride(1) : C.stride(2); - params.u_batch_stride = u.stride(0); - params.u_d_stride = u.stride(1); - params.delta_batch_stride = delta.stride(0); - params.delta_d_stride = delta.stride(1); - if (has_z) { - params.z_batch_stride = z.stride(0); - params.z_d_stride = z.stride(1); - params.out_z_batch_stride = out_z.stride(0); - params.out_z_d_stride = out_z.stride(1); - } - params.out_batch_stride = out.stride(0); - params.out_d_stride = out.stride(1); -} - -void set_ssm_params_bwd(SSMParamsBwd ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const at::Tensor u, - const at::Tensor delta, - const at::Tensor A, - const at::Tensor B, - const at::Tensor C, - const at::Tensor z, - const at::Tensor out, - const at::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - const at::Tensor dout, - const at::Tensor du, - const at::Tensor ddelta, - const at::Tensor dA, - const at::Tensor dB, - const at::Tensor dC, - const at::Tensor dz, - void* dD_ptr, - void* ddelta_bias_ptr, - bool has_z, - bool delta_softplus, - bool recompute_out_z) { - // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z - set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, has_z ? out : dout, - has_z ? z : dout, - // If not recompute_out_z, pass dout instead of out_z. - // This won't be used by the bwd kernel - recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); - if (!recompute_out_z) { params.out_z_ptr = nullptr; } - - // Set the pointers and strides. - params.dout_ptr = dout.data_ptr(); - params.du_ptr = du.data_ptr(); - params.dA_ptr = dA.data_ptr(); - params.dB_ptr = dB.data_ptr(); - params.dC_ptr = dC.data_ptr(); - params.dD_ptr = dD_ptr; - params.ddelta_ptr = ddelta.data_ptr(); - params.ddelta_bias_ptr = ddelta_bias_ptr; - params.dz_ptr = has_z ? dz.data_ptr() : nullptr; - // All stride are in elements, not bytes. - params.dout_batch_stride = dout.stride(0); - params.dout_d_stride = dout.stride(1); - params.dA_d_stride = dA.stride(0); - params.dA_dstate_stride = dA.stride(1); - if (!is_variable_B) { - params.dB_d_stride = dB.stride(0); - } else { - params.dB_batch_stride = dB.stride(0); - params.dB_group_stride = dB.stride(1); - } - params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); - if (!is_variable_C) { - params.dC_d_stride = dC.stride(0); - } else { - params.dC_batch_stride = dC.stride(0); - params.dC_group_stride = dC.stride(1); - } - params.dC_dstate_stride = !is_variable_C ? 
dC.stride(1) : dC.stride(2); - params.du_batch_stride = du.stride(0); - params.du_d_stride = du.stride(1); - params.ddelta_batch_stride = ddelta.stride(0); - params.ddelta_d_stride = ddelta.stride(1); - if (has_z) { - params.dz_batch_stride = dz.stride(0); - params.dz_d_stride = dz.stride(1); - } -} - -std::vector<at::Tensor> -selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, - const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, - const c10::optional<at::Tensor> &D_, - const c10::optional<at::Tensor> &z_, - const c10::optional<at::Tensor> &delta_bias_, - bool delta_softplus) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1); - } - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - - at::Tensor z, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); - // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout - at::Tensor out = torch::empty_like(delta); - at::Tensor x; - x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); - - SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, out, z, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.data_ptr(), - has_z, - delta_softplus); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda<input_t, weight_t>(params, stream); - }); - }); - std::vector<at::Tensor> result = {out, x}; - if (has_z) { result.push_back(out_z); } - return result; -} - -std::vector<at::Tensor> -selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, - const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, - const c10::optional<at::Tensor> &D_, - const c10::optional<at::Tensor> &z_, - const c10::optional<at::Tensor> &delta_bias_, - const at::Tensor &dout, - const c10::optional<at::Tensor> &x_, - const c10::optional<at::Tensor> &out_, - c10::optional<at::Tensor> &dz_, - bool delta_softplus, - bool recompute_out_z) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? 
weight_type : input_type)); - TORCH_CHECK(dout.scalar_type() == input_type); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - TORCH_CHECK(dout.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1); - TORCH_CHECK(dout.stride(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1); - } - CHECK_SHAPE(dout, batch_size, dim, seqlen); - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - - at::Tensor z, out, dz, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - - TORCH_CHECK(out_.has_value()); - out = out_.value(); - TORCH_CHECK(out.scalar_type() == input_type); - TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(out.stride(-1) == 1); - CHECK_SHAPE(out, batch_size, dim, seqlen); - - if (dz_.has_value()) { - dz = dz_.value(); - TORCH_CHECK(dz.scalar_type() == input_type); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(dz.stride(-1) == 1); - CHECK_SHAPE(dz, batch_size, dim, seqlen); - } else { - dz = torch::empty_like(z); - } - if (recompute_out_z) { - out_z = torch::empty_like(out); - } - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } - if (x_.has_value()) { - auto x = x_.value(); - TORCH_CHECK(x.scalar_type() == weight_type); - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(x.is_contiguous()); - CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); - } - - at::Tensor du = torch::empty_like(u); - at::Tensor ddelta = torch::empty_like(delta); - at::Tensor dA = torch::zeros_like(A); - at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); - at::Tensor dC = !is_variable_C ? 
torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); - at::Tensor dD; - if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } - at::Tensor ddelta_bias; - if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } - - SSMParamsBwd params; - set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, z, out, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x_.has_value() ? x_.value().data_ptr() : nullptr, - dout, du, ddelta, dA, dB, dC, dz, - D_.has_value() ? dD.data_ptr() : nullptr, - delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, - has_z, delta_softplus, recompute_out_z); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { - selective_scan_bwd_cuda<input_t, weight_t>(params, stream); - }); - }); - std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; - if (has_z) { result.push_back(dz); } - if (recompute_out_z) { result.push_back(out_z); } - return result; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fwd", &selective_scan_fwd, "Selective scan forward"); - m.def("bwd", &selective_scan_bwd, "Selective scan backward"); -} diff --git a/mamba/csrc/selective_scan/selective_scan.h b/mamba/csrc/selective_scan/selective_scan.h deleted file mode 100644 index e2c7bcdbd5ddadc5975caa641ecb1dcd3b73dafd..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan.h +++ /dev/null @@ -1,101 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMScanParamsBase { - using index_t = uint32_t; - - int batch, seqlen, n_chunks; - index_t a_batch_stride; - index_t b_batch_stride; - index_t out_batch_stride; - - // Common data pointers. - void *__restrict__ a_ptr; - void *__restrict__ b_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, dstate, n_groups, n_chunks; - int dim_ngroups_ratio; - bool is_variable_B; - bool is_variable_C; - - bool delta_softplus; - - index_t A_d_stride; - index_t A_dstate_stride; - index_t B_batch_stride; - index_t B_d_stride; - index_t B_dstate_stride; - index_t B_group_stride; - index_t C_batch_stride; - index_t C_d_stride; - index_t C_dstate_stride; - index_t C_group_stride; - index_t u_batch_stride; - index_t u_d_stride; - index_t delta_batch_stride; - index_t delta_d_stride; - index_t z_batch_stride; - index_t z_d_stride; - index_t out_batch_stride; - index_t out_d_stride; - index_t out_z_batch_stride; - index_t out_z_d_stride; - - // Common data pointers. 
- void *__restrict__ A_ptr; - void *__restrict__ B_ptr; - void *__restrict__ C_ptr; - void *__restrict__ D_ptr; - void *__restrict__ u_ptr; - void *__restrict__ delta_ptr; - void *__restrict__ delta_bias_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; - void *__restrict__ z_ptr; - void *__restrict__ out_z_ptr; -}; - -struct SSMParamsBwd: public SSMParamsBase { - index_t dout_batch_stride; - index_t dout_d_stride; - index_t dA_d_stride; - index_t dA_dstate_stride; - index_t dB_batch_stride; - index_t dB_group_stride; - index_t dB_d_stride; - index_t dB_dstate_stride; - index_t dC_batch_stride; - index_t dC_group_stride; - index_t dC_d_stride; - index_t dC_dstate_stride; - index_t du_batch_stride; - index_t du_d_stride; - index_t dz_batch_stride; - index_t dz_d_stride; - index_t ddelta_batch_stride; - index_t ddelta_d_stride; - - // Common data pointers. - void *__restrict__ dout_ptr; - void *__restrict__ dA_ptr; - void *__restrict__ dB_ptr; - void *__restrict__ dC_ptr; - void *__restrict__ dD_ptr; - void *__restrict__ du_ptr; - void *__restrict__ dz_ptr; - void *__restrict__ ddelta_ptr; - void *__restrict__ ddelta_bias_ptr; -}; diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu b/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu deleted file mode 100644 index c55f0e858af4ebd246a5d251308ab920b4e01a50..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu b/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu deleted file mode 100644 index 72adaf5cb13c6429e2f345a0a823c6bc3722b95a..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_bf16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu deleted file mode 100644 index df126d7c8d5f9f0862273d2fe21ea15b35757b64..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu deleted file mode 100644 index 3ff271b50eaff208ae33c16c87ab7aaee76dfd76..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_fp16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu deleted file mode 100644 index 5554902342785b289b81c060a71a51734fc1e6bf..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu b/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu deleted file mode 100644 index a7ed642231da80c455c0499702cc8a1cb4536ec2..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_fp32_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh deleted file mode 100644 index 2ed101148a4b32560111e5a832fc8b5881a4b243..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ /dev/null @@ -1,531 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
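 *
 * Backward kernel for the selective scan: chunks are walked from last to
 * first, the forward recurrence is recomputed with a block-wide inclusive
 * scan, and a block-wide reverse scan propagates the adjoint state so that
 * gradients for u, delta, A, B, C (and D, delta_bias, z when present) can be
 * accumulated.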
- ******************************************************************************/ - -#pragma once - -#include <c10/util/BFloat16.h> -#include <c10/util/Half.h> -#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK -#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex - -#include <cub/block/block_load.cuh> -#include <cub/block/block_store.cuh> -#include <cub/block/block_scan.cuh> -#include <cub/block/block_reduce.cuh> - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "reverse_scan.cuh" -#include "static_switch.h" - -template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x); -template<> __device__ __forceinline__ float conj<float>(float x) { return x; } -template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); } - -template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_, - bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_> -struct Selective_Scan_bwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kNItems = kNItems_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; - static constexpr bool kHasZ = kHasZ_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. - // For complex this would lead to massive register spilling, so we keep it at 2. - static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; - using vec_t = typename BytesToType<kNBytes * kNElts>::Type; - using scan_t = std::conditional_t<!kIsComplex, float2, float4>; - using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>; - using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>; - // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>; - using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>; - // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>; - using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>; - using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>; - using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>; - using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>; - using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? 
kNItems : kNItems * 2>; - static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); - static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_bwd_kernel(SSMParamsBwd params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan); - auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_); - auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_); - auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_); - auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize); - auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); - auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize); - auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce); - auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce); - auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize); - auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); - weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize); - scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads); - weight_t *smem_da 
= reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE); - weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * params.u_d_stride; - input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * params.delta_d_stride; - input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride - + dim_id * params.dout_d_stride; - weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride; - weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride; - input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride; - input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride; - weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr) - + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); - weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr) - + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); - float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id; - float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id]; - float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id; - float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id]; - scan_t *x = params.x_ptr == nullptr - ? nullptr - : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; - float dD_val = 0; - float ddelta_bias_val = 0; - - constexpr int kChunkSize = kNThreads * kNItems; - u += (params.n_chunks - 1) * kChunkSize; - delta += (params.n_chunks - 1) * kChunkSize; - dout += (params.n_chunks - 1) * kChunkSize; - Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 
1 : 2); - for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { - input_t u_vals[kNItems]; - input_t delta_vals_load[kNItems]; - input_t dout_vals_load[kNItems]; - __syncthreads(); - load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); - u -= kChunkSize; - __syncthreads(); - load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - // Will reload delta at the same location if kDeltaSoftplus - if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } - __syncthreads(); - load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - dout -= kChunkSize; - - float dout_vals[kNItems], delta_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dout_vals[i] = float(dout_vals_load[i]); - delta_vals[i] = float(delta_vals_load[i]) + delta_bias; - if constexpr (kDeltaSoftplus) { - delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; - } - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * params.z_d_stride + chunk * kChunkSize; - input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * params.out_d_stride + chunk * kChunkSize; - input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride - + dim_id * params.dz_d_stride + chunk * kChunkSize; - input_t z_vals[kNItems], out_vals[kNItems]; - __syncthreads(); - load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - __syncthreads(); - load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); - float dz_vals[kNItems], z_silu_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); - z_silu_vals[i] = z_val * z_sigmoid_val; - dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val - * (1.0f + z_val * (1.0f - z_sigmoid_val)); - dout_vals[i] *= z_silu_vals[i]; - } - __syncthreads(); - store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); - if (params.out_z_ptr != nullptr) { // Recompute and store out_z - float out_z_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); - // } - input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * params.out_z_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); - } - } - - float du_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } - - float ddelta_vals[kNItems] = {0}; - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - const weight_t A_val = A[state_idx * params.A_dstate_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. 
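            // (exp(x) == exp2(x * log2(e)), hence the M_LOG2E scaling here and the exp2f / cexp2f calls below.)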
- weight_t A_scaled; - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_scaled = A_val * kLog2e; - } else { - A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); - } - weight_t B_val, C_val; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (!kIsVariableB) { - B_val = B[state_idx * params.B_dstate_stride]; - } else { - load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - if constexpr (!kIsVariableC) { - C_val = C[state_idx * params.C_dstate_stride]; - } else { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - // const weight_t A_val = smem_a[state_idx]; - scan_t thread_data[kNItems], thread_reverse_data[kNItems]; - if constexpr (!kIsComplex) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); - thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp; - } - thread_reverse_data[i].y = dout_vals[i] * - (!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); - } - __syncthreads(); - thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix); - Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float dx = thread_reverse_data[i].y; - const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; - du_vals[i] += ddelta_u * delta_vals[i]; - const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; - dA_val += dx * delta_vals[i] * a; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += dout_vals[i] * thread_data[i].y; - } - } - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - if constexpr (kIsVariableB) { - Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); - } - if constexpr (kIsVariableC) { - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); - } - const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; - weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; - weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } - if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float2 dA_dBC_val = make_float2(dA_val, dBC_val); - dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = dA_dBC_val.x; - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; - } - } else { - dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } else { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); - weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); - thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp.real_; - thread_reverse_data[i - 1].y = -delta_a_exp.imag_; - } - complex_t dout_BC = 2 * dout_vals[i] - * conj(!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); - thread_reverse_data[i].z = dout_BC.real_; - thread_reverse_data[i].w = dout_BC.imag_; - } - __syncthreads(); - complex_t delta_a_exp = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; - thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; - thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? 
x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix); - Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - complex_t x = complex_t(thread_data[i].z, thread_data[i].w); - complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); - float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += (2 * dout_vals[i]) * conj(x); - } - } - const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); - du_vals[i] += ddelta_u * delta_vals[i]; - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; - dA_val += delta_vals[i] * dx * a_conj; - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; - if constexpr (kIsVariableB) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dB_vals_f[i * 2] = dB_vals[i].real_; - dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; - } - Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); - } - if constexpr (kIsVariableC) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dC_vals_f[i * 2] = dC_vals[i].real_; - dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; - } - auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1; - Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); - } - const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; - float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems * 2; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } - if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); - dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); - dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx]; - } - } else { - dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } - } - - if constexpr (kDeltaSoftplus) { - __syncthreads(); - input_t delta_vals_load[kNItems]; - load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - delta -= kChunkSize; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float delta_val = float(delta_vals_load[i]) + delta_bias; - float delta_val_neg_exp = expf(-delta_val); - ddelta_vals[i] = delta_val <= 20.f - ? ddelta_vals[i] / (1.f + delta_val_neg_exp) - : ddelta_vals[i]; - } - } - for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } - - input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride - + dim_id * params.du_d_stride + chunk * kChunkSize; - input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride - + dim_id * params.ddelta_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); - __syncthreads(); - store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); - - Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); - Cvar -= kChunkSize * (!kIsComplex ? 1 : 2); - } - if (params.dD_ptr != nullptr) { - dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); - if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } - } - if (params.ddelta_bias_ptr != nullptr) { - __syncthreads(); - ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); - if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } - } - for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); - weight_t dBC_val; - if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } - if constexpr (!kIsVariableB) { - gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), - !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); - } - if constexpr (!kIsVariableC) { - gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), - !kIsVariableB ? 
dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); - } - } -} - -template<int kNThreads, int kNItems, typename input_t, typename weight_t> -void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>; - // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>; - // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim); - auto kernel = &selective_scan_bwd_kernel<Ktraits>; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - }); -} - -template<typename input_t, typename weight_t> -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { - if (params.seqlen <= 128) { - selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); - } -} \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_common.h b/mamba/csrc/selective_scan/selective_scan_common.h deleted file mode 100644 index 9140dcdf3b68ad2de95bcd3fd9543a9d320cef68..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_common.h +++ /dev/null @@ -1,221 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
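The backward launcher above chooses kNThreads and kNItems so that a single block covers kNThreads * kNItems timesteps per chunk, and the five nested BOOL_SWITCHes expand into 32 template instantiations per (input_t, weight_t) pair. A hypothetical Python mirror of the seqlen dispatch table, for reading convenience only:

```python
def bwd_launch_config(seqlen: int) -> tuple[int, int]:
    """Mirror of selective_scan_bwd_cuda's seqlen -> (kNThreads, kNItems) selection."""
    if seqlen <= 128:
        return 32, 4
    elif seqlen <= 256:
        return 32, 8
    elif seqlen <= 512:
        return 32, 16
    elif seqlen <= 1024:
        return 64, 16
    return 128, 16

threads, items = bwd_launch_config(2048)  # (128, 16): one chunk spans 128 * 16 = 2048 timesteps
```

kIsEvenLen is one of those switches, which is what lets the even-length path use the vectorized, bounds-check-free loads and stores.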
- ******************************************************************************/ - -#pragma once - -#include <cuda_bf16.h> -#include <cuda_fp16.h> -#include <c10/util/complex.h> // For scalar_value_type - -#define MAX_DSTATE 256 - -using complex_t = c10::complex<float>; - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -inline __device__ float3 operator+(const float3 &a, const float3 &b) { - return {a.x + b.x, a.y + b.y, a.z + b.z}; -} - -inline __device__ float4 operator+(const float4 & a, const float4 & b){ - return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<int BYTES> struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<typename scalar_t, int N> -struct Converter{ - static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { dst[i] = src[i]; } - } -}; - -template<int N> -struct Converter<at::Half, N>{ - static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src); - auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } - } -}; - -#if __CUDA_ARCH__ >= 800 -template<int N> -struct Converter<at::BFloat16, N>{ - static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src); - auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp -// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 -__device__ __forceinline__ complex_t cexp2f(complex_t z) { - float t = exp2f(z.real_); - float c, s; - sincosf(z.imag_, &s, &c); - return complex_t(c * t, s * t); -} - -__device__ __forceinline__ complex_t cexpf(complex_t z) { - float t = expf(z.real_); - float c, s; - sincosf(z.imag_, &s, &c); - return complex_t(c * t, s * t); -} - -template<typename scalar_t> struct SSMScanOp; - -template<> -struct SSMScanOp<float> { - __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { - return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); - } -}; - -template<> -struct SSMScanOp<complex_t> { - __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { - complex_t a0 = complex_t(ab0.x, ab0.y); - complex_t b0 = complex_t(ab0.z, ab0.w); - complex_t a1 = complex_t(ab1.x, 
ab1.y); - complex_t b1 = complex_t(ab1.z, ab1.w); - complex_t out_a = a1 * a0; - complex_t out_b = a1 * b0 + b1; - return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); - } -}; - -// A stateful callback functor that maintains a running prefix to be applied -// during consecutive scan operations. -template <typename scalar_t> struct SSMScanPrefixCallbackOp { - using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>; - scan_t running_prefix; - // Constructor - __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide scan. - __device__ scan_t operator()(scan_t block_aggregate) { - scan_t old_prefix = running_prefix; - running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate); - return old_prefix; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<typename Ktraits> -inline __device__ void load_input(typename Ktraits::input_t *u, - typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadT::TempStorage &smem_load, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadVecT(smem_load_vec).Load( - reinterpret_cast<vec_t*>(u), - reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals) - ); - } else { - Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); - } -} - -template<typename Ktraits> -inline __device__ void load_weight(typename Ktraits::input_t *Bvar, - typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, - int seqlen) { - constexpr int kNItems = Ktraits::kNItems; - if constexpr (!Ktraits::kIsComplex) { - typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast<vec_t*>(Bvar), - reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load) - ); - } else { - Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - // #pragma unroll - // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } - Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals); - } else { - typename Ktraits::input_t B_vals_load[kNItems * 2]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast<vec_t*>(Bvar), - reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load) - ); - } else { - Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } - } -} - -template<typename Ktraits> -inline __device__ void store_output(typename Ktraits::input_t *out, - const float (&out_vals)[Ktraits::kNItems], - typename Ktraits::BlockStoreT::TempStorage 
&smem_store, - int seqlen) { - typename Ktraits::input_t write_vals[Ktraits::kNItems]; - #pragma unroll - for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockStoreVecT(smem_store_vec).Store( - reinterpret_cast<vec_t*>(out), - reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals) - ); - } else { - Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); - } -} diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu b/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu deleted file mode 100644 index 2b8615b1d522c119125d4cb6ff3dce42f2bd4659..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_fwd_bf16.cu +++ /dev/null @@ -1,10 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu b/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu deleted file mode 100644 index 015e2a0eff633daf2693e43a2648008652a38c7c..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_fwd_fp16.cu +++ /dev/null @@ -1,10 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu b/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu deleted file mode 100644 index c142fe0208ea784679122ba04997d3432b05efcc..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_fwd_fp32.cu +++ /dev/null @@ -1,10 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. 
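The core abstraction in selective_scan_common.h is SSMScanOp, which packs one step of the recurrence x_t = a_t * x_{t-1} + b_t as the pair (a, b) and combines two steps as (a1 * a0, a1 * b0 + b1); because that combine is associative with identity (1, 0), CUB's block scan and the cross-chunk prefix/postfix callbacks can evaluate the recurrence in parallel. A small Python check of the real-weight (float2) case:

```python
def ssm_combine(ab0, ab1):
    """SSMScanOp<float>: compose two steps of x_t = a_t * x_{t-1} + b_t."""
    a0, b0 = ab0
    a1, b1 = ab1
    return (a1 * a0, a1 * b0 + b1)

def scan_recurrence(a, b, x0=0.0):
    """Direct sequential evaluation of the same recurrence, for comparison."""
    out, x = [], x0
    for a_t, b_t in zip(a, b):
        x = a_t * x + b_t
        out.append(x)
    return out

a = [0.9, 0.8, 0.7, 0.6]
b = [1.0, 2.0, 3.0, 4.0]

acc, scanned = (1.0, 0.0), []          # (1, 0) is the identity, as in the running_prefix initializers
for step in zip(a, b):
    acc = ssm_combine(acc, step)
    scanned.append(acc[1])             # the 'b' slot of the inclusive scan is the state x_t
assert scanned == scan_recurrence(a, b)
```

The complex (float4) specialization is the same operator with complex multiplication, and SSMScanPrefixCallbackOp simply keeps folding block aggregates into this running pair.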
- ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh deleted file mode 100644 index 440a209108bfe120c73d123bbf0b82ccf43a5638..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ /dev/null @@ -1,345 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include <c10/util/BFloat16.h> -#include <c10/util/Half.h> -#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include <cub/block/block_load.cuh> -#include <cub/block/block_store.cuh> -#include <cub/block/block_scan.cuh> - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "static_switch.h" - -template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_, - bool kIsVariableB_, bool kIsVariableC_, - bool kHasZ_, typename input_t_, typename weight_t_> -struct Selective_Scan_fwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. - static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; - static constexpr int kNItems = kNItems_; - static constexpr int kNRows = kNRows_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kHasZ = kHasZ_; - - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - - using vec_t = typename BytesToType<kNBytes * kNElts>::Type; - using scan_t = std::conditional_t<!kIsComplex, float2, float4>; - using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, - !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>; - using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, - !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>; - using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>; - using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, - !kDirectIO ? 
cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>; - // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>; - // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>; - using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>; - static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); -}; - -template<typename Ktraits> -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_fwd_kernel(SSMParamsBase params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - constexpr int kNRows = Ktraits::kNRows; - constexpr bool kDirectIO = Ktraits::kDirectIO; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan); - auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_); - auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_); - auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_); - auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize); - // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size); - // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE); - scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * kNRows * params.u_d_stride; - input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * kNRows * params.delta_d_stride; - weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride; - weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride; - input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride; - 
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - - float D_val[kNRows] = {0}; - if (params.D_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r]; - } - } - float delta_bias[kNRows] = {0}; - if (params.delta_bias_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r]; - } - } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { - input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (!kDirectIO) { __syncthreads(); } - load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); - } - u += kChunkSize; - delta += kChunkSize; - - float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float u_val = float(u_vals[r][i]); - delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; - if (params.delta_softplus) { - delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; - } - delta_u_vals[r][i] = delta_vals[r][i] * u_val; - out_vals[r][i] = D_val[r] * u_val; - } - } - - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - weight_t A_val[kNRows]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_val[r] *= kLog2e; - } else { - A_val[r].real_ *= kLog2e; - } - } - // This variable holds B * C if both B and C are constant across seqlen. If only B varies - // across seqlen, this holds C. If only C varies across seqlen, this holds B. - // If both B and C vary, this is unused. - weight_t BC_val[kNRows]; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { - load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - if constexpr (!kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - } - if constexpr (kIsVariableC) { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); - if constexpr (!kIsVariableB) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; - } - } - } - if constexpr (!kIsVariableB && !kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if (r > 0) { __syncthreads(); } // Scan could be using the same smem - scan_t thread_data[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if constexpr (!kIsComplex) { - thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } - } - } else { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); - weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; - thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); - } - } - } - } - // Initialize running total - scan_t running_prefix; - if constexpr (!kIsComplex) { - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); - } else { - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - } - SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op - ); - // There's a syncthreads in the scan op, so we don't need to sync here. - // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. - if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const weight_t C_val = !kIsVariableC - ? BC_val[r] - : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); - if constexpr (!kIsComplex) { - out_vals[r][i] += thread_data[i].y * C_val; - } else { - out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; - } - } - } - } - - input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - input_t z_vals[kNItems]; - __syncthreads(); - load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - out_vals[r][i] *= z_val / (1 + expf(-z_val)); - } - __syncthreads(); - store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - } - - Bvar += kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += kChunkSize * (!kIsComplex ? 1 : 2); - } -} - -template<int kNThreads, int kNItems, typename input_t, typename weight_t> -void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { - // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block - // processing 1 row. 
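Taken together, the forward kernel above discretizes the SSM per timestep (delta_a_exp from Δ·A via exp2f, plus Δ·B·u), runs the block-wide inclusive scan with a per-chunk running prefix, contracts the state with C, and finally adds the D skip term and the optional SiLU z-gate. A compact PyTorch reference for the variable-B/C, real-weight case, written without chunking and with shapes stated as assumptions:

```python
import torch
import torch.nn.functional as F

def selective_scan_reference(u, delta, A, B, C, D=None, z=None, delta_softplus=True):
    """u, delta, z: (batch, dim, seqlen); A: (dim, dstate); B, C: (batch, dstate, seqlen)."""
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = u.new_zeros(batch, dim, dstate)                       # SSM state
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None] * A)              # discretized transition
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        x = dA * x + dBu                                      # the scanned recurrence
        ys.append(torch.einsum("bdn,bn->bd", x, C[:, :, t]))  # readout through C
    out = torch.stack(ys, dim=-1)
    if D is not None:
        out = out + D[None, :, None] * u                      # skip connection
    if z is not None:
        out = out * F.silu(z)                                 # kHasZ gating
    return out
```

The kernel computes the same quantities but replaces the time loop with the associative block scan, evaluates exp as exp2f on log2(e)-scaled A, and threads the state across chunks through smem_running_prefix and the per-chunk x checkpoint buffer used later by the backward pass.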
- constexpr int kNRows = 1; - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel<Ktraits>; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); -} - -template<typename input_t, typename weight_t> -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { - if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } -} diff --git a/mamba/csrc/selective_scan/static_switch.h b/mamba/csrc/selective_scan/static_switch.h deleted file mode 100644 index 7920ac045d0a2a1f4c4159ee3eebe51fe1e2c203..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/static_switch.h +++ /dev/null @@ -1,25 +0,0 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function<BoolConst>(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/mamba/csrc/selective_scan/uninitialized_copy.cuh b/mamba/csrc/selective_scan/uninitialized_copy.cuh deleted file mode 100644 index 630622dddcc9041737307810000584a843a01764..0000000000000000000000000000000000000000 --- a/mamba/csrc/selective_scan/uninitialized_copy.cuh +++ /dev/null @@ -1,69 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include <cub/config.cuh> - -#include <cuda/std/type_traits> - - -namespace detail -{ - -#if defined(_NVHPC_CUDA) -template <typename T, typename U> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - // NVBug 3384810 - new (ptr) T(::cuda::std::forward<U>(val)); -} -#else -template <typename T, - typename U, - typename ::cuda::std::enable_if< - ::cuda::std::is_trivially_copyable<T>::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - *ptr = ::cuda::std::forward<U>(val); -} - -template <typename T, - typename U, - typename ::cuda::std::enable_if< - !::cuda::std::is_trivially_copyable<T>::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - new (ptr) T(::cuda::std::forward<U>(val)); -} -#endif - -} // namespace detail diff --git a/mamba/evals/lm_harness_eval.py b/mamba/evals/lm_harness_eval.py deleted file mode 100644 index d09d40534cf53be4d1387666697c82aa53add625..0000000000000000000000000000000000000000 --- a/mamba/evals/lm_harness_eval.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch - -import transformers -from transformers import AutoTokenizer - -from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel - -from lm_eval.api.model import LM -from lm_eval.models.huggingface import HFLM -from lm_eval.api.registry import register_model -from lm_eval.__main__ import cli_evaluate - - -@register_model("mamba") -class MambaEvalWrapper(HFLM): - - AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM - - def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", - dtype=torch.float16): - LM.__init__(self) - self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) - self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.vocab_size = self.tokenizer.vocab_size - self._batch_size = batch_size if batch_size is None else 64 - self._max_length = max_length - self._device = torch.device(device) - - @property - def batch_size(self): - return self._batch_size - - def _model_generate(self, context, max_length, stop, **generation_kwargs): - raise 
NotImplementedError() - - -if __name__ == "__main__": - cli_evaluate() diff --git a/mamba/mamba_ssm/__init__.py b/mamba/mamba_ssm/__init__.py deleted file mode 100644 index 2ecd144db5dbec72bcfcdcea28c624a7e2bf053b..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -__version__ = "1.0.1" - -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn -from mamba_ssm.modules.mamba_simple import Mamba -from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel diff --git a/mamba/mamba_ssm/models/__init__.py b/mamba/mamba_ssm/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mamba/mamba_ssm/models/mixer_seq_simple.py b/mamba/mamba_ssm/models/mixer_seq_simple.py deleted file mode 100644 index 383f773f1f700cd53176e51327a5d8dc58158da0..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/models/mixer_seq_simple.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) 2023, Albert Gu, Tri Dao. - -import math -from functools import partial - -from collections import namedtuple - -import torch -import torch.nn as nn - -from mamba_ssm.modules.mamba_simple import Mamba, Block -from mamba_ssm.utils.generation import GenerationMixin -from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - -def create_block( - d_model, - ssm_cfg=None, - norm_epsilon=1e-5, - rms_norm=False, - residual_in_fp32=False, - fused_add_norm=False, - layer_idx=None, - device=None, - dtype=None, -): - if ssm_cfg is None: - ssm_cfg = {} - factory_kwargs = {"device": device, "dtype": dtype} - mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) - norm_cls = partial( - nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs - ) - block = Block( - d_model, - mixer_cls, - norm_cls=norm_cls, - fused_add_norm=fused_add_norm, - residual_in_fp32=residual_in_fp32, - ) - block.layer_idx = layer_idx - return block - - -# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 -def _init_weights( - module, - n_layer, - initializer_range=0.02, # Now only used for embedding layer. - rescale_prenorm_residual=True, - n_residuals_per_layer=1, # Change to 2 if we have MLP -): - if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - - if rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
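Concretely, the rescaling that follows re-draws each residual output projection and divides it by sqrt(n_residuals_per_layer * n_layer), so the variance accumulated along the residual path stays roughly constant as depth grows. A standalone sketch of that step (applied over a whole model here, which is an assumption; the file applies it per module via .apply):

```python
import math
import torch
import torch.nn as nn

def rescale_residual_projections(model: nn.Module, n_layer: int, n_residuals_per_layer: int = 1):
    """GPT-2-style residual scaling: shrink out_proj / fc2 weights by 1/sqrt(N)."""
    for name, p in model.named_parameters():
        if name.endswith(("out_proj.weight", "fc2.weight")):
            # Re-draw first so calling this more than once does not keep shrinking the weights.
            nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            with torch.no_grad():
                p /= math.sqrt(n_residuals_per_layer * n_layer)
```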
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * n_layer) - - -class MixerModel(nn.Module): - def __init__( - self, - d_model: int, - n_layer: int, - vocab_size: int, - ssm_cfg=None, - norm_epsilon: float = 1e-5, - rms_norm: bool = False, - initializer_cfg=None, - fused_add_norm=False, - residual_in_fp32=False, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - - self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) - - # We change the order of residual and layer norm: - # Instead of LN -> Attn / MLP -> Add, we do: - # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and - # the main branch (output of MLP / Mixer). The model definition is unchanged. - # This is for performance reason: we can fuse add + layer_norm. - self.fused_add_norm = fused_add_norm - if self.fused_add_norm: - if layer_norm_fn is None or rms_norm_fn is None: - raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") - - self.layers = nn.ModuleList( - [ - create_block( - d_model, - ssm_cfg=ssm_cfg, - norm_epsilon=norm_epsilon, - rms_norm=rms_norm, - residual_in_fp32=residual_in_fp32, - fused_add_norm=fused_add_norm, - layer_idx=i, - **factory_kwargs, - ) - for i in range(n_layer) - ] - ) - - self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( - d_model, eps=norm_epsilon, **factory_kwargs - ) - - self.apply( - partial( - _init_weights, - n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}), - ) - ) - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return { - i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - for i, layer in enumerate(self.layers) - } - - def forward(self, input_ids, inference_params=None): - hidden_states = self.embedding(input_ids) - residual = None - for layer in self.layers: - hidden_states, residual = layer( - hidden_states, residual, inference_params=inference_params - ) - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) - else: - # Set prenorm=False here since we don't need the residual - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn - hidden_states = fused_add_norm_fn( - hidden_states, - self.norm_f.weight, - self.norm_f.bias, - eps=self.norm_f.eps, - residual=residual, - prenorm=False, - residual_in_fp32=self.residual_in_fp32, - ) - return hidden_states - - -class MambaLMHeadModel(nn.Module, GenerationMixin): - - def __init__( - self, - d_model: int, - n_layer: int, - vocab_size: int, - initializer_cfg=None, - pad_vocab_size_multiple: int = 1, - device=None, - dtype=None, - **backbone_kwargs, - ) -> None: - factory_kwargs = {"device": device, "dtype": 
dtype} - super().__init__() - if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) - self.backbone = MixerModel( - d_model=d_model, - n_layer=n_layer, - vocab_size=vocab_size, - initializer_cfg=initializer_cfg, - **backbone_kwargs, - **factory_kwargs, - ) - self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) - - # Initialize weights and apply final processing - self.apply( - partial( - _init_weights, - n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}), - ) - ) - self.tie_weights() - - def tie_weights(self): - self.lm_head.weight = self.backbone.embedding.weight - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): - """ - "position_ids" is just to be compatible with Transformer generation. We don't use it. - num_last_tokens: if > 0, only return the logits for the last n tokens - """ - hidden_states = self.backbone(input_ids, inference_params=inference_params) - if num_last_tokens > 0: - hidden_states = hidden_states[:, -num_last_tokens:] - lm_logits = self.lm_head(hidden_states) - CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) - return CausalLMOutput(logits=lm_logits) - - @classmethod - def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): - config = load_config_hf(pretrained_model_name) - model = cls(**config, device=device, dtype=dtype, **kwargs) - model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) - return model diff --git a/mamba/mamba_ssm/modules/__init__.py b/mamba/mamba_ssm/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mamba/mamba_ssm/modules/mamba_simple.py b/mamba/mamba_ssm/modules/mamba_simple.py deleted file mode 100644 index 2a1dd8f808b50d632bbd22f0648d4cb8939cb1e1..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/modules/mamba_simple.py +++ /dev/null @@ -1,418 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. 
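MambaLMHeadModel above pads the vocabulary up to a multiple of pad_vocab_size_multiple, ties lm_head.weight to the token embedding, and returns a namedtuple whose only field is logits. A hypothetical usage sketch against that interface (the checkpoint name, device, and shapes are assumptions, not something this diff pins down):

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, 1000, (1, 16), device="cuda")

with torch.no_grad():
    # num_last_tokens=1 keeps only the final position's logits, as in incremental decoding.
    logits = model(input_ids, num_last_tokens=1).logits       # (1, 1, padded_vocab_size)

# tie_weights() makes the LM head share storage with the embedding matrix.
assert model.lm_head.weight.data_ptr() == model.backbone.embedding.weight.data_ptr()
```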
- -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj -except ImportError: - selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None - -try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: - selective_state_update = None - -try: - from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -except ImportError: - RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - - -class Mamba(nn.Module): - def __init__( - self, - d_model, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, - bimamba=True, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank - self.use_fast_path = use_fast_path - self.layer_idx = layer_idx - self.bimamba = bimamba - - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - - self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.activation = "silu" - self.act = nn.SiLU() - - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True - - # S4D real initialization - # NOTE: why plus 1? 
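The dt initialization above first samples Δ log-uniformly in [dt_min, dt_max] and then stores its softplus-inverse in dt_proj.bias, so that F.softplus(bias) reproduces the sampled Δ at initialization; the bias is then marked _no_reinit so _init_weights leaves it alone. A short check of that inverse-softplus identity (a sketch with toy sizes, not the module itself):

```python
import math
import torch
import torch.nn.functional as F

d_inner, dt_min, dt_max, dt_init_floor = 8, 1e-3, 0.1, 1e-4

# Log-uniform sample in [dt_min, dt_max], floored as in Mamba.__init__.
dt = torch.exp(
    torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)

# softplus(dt + log(-expm1(-dt))) == dt, so softplus(dt_proj.bias) recovers the sampled step sizes.
inv_dt = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)
```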
- A = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D._no_weight_decay = True - - # bidirectional - # forked from https://github.com/hustvl/Vim - if self.bimamba: - A_b = repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_b_log = torch.log(A_b) # Keep A_b_log in fp32 - self.A_b_log = nn.Parameter(A_b_log) - self.A_b_log._no_weight_decay = True - - self.conv1d_b = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, - bias=conv_bias, - kernel_size=d_conv, - groups=self.d_inner, - padding=d_conv - 1, - **factory_kwargs, - ) - - self.x_proj_b = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs - ) - self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) - - self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 - self.D_b._no_weight_decay = True - - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - def forward(self, hidden_states, inference_params=None, T=1): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - # We do matmul and transpose BLH -> HBL at the same time - # NOTE: same as in_proj(hidden_states) but memory-efficient with the following operations - xz = rearrange( - self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self.in_proj.bias is not None: - xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states - if self.bimamba: - A_b = -torch.exp(self.A_b_log.float()) - out = mamba_inner_fn_no_out_proj( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - out_b = mamba_inner_fn_no_out_proj( - xz.flip([-1]), - self.conv1d_b.weight, - self.conv1d_b.bias, - self.x_proj_b.weight, - self.dt_proj_b.weight, - A_b, - None, - None, - self.D_b.float(), - delta_bias=self.dt_proj_b.bias.float(), - delta_softplus=True, - ) - out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) - else: - out = mamba_inner_fn( - xz, - self.conv1d.weight, - self.conv1d.bias, - self.x_proj.weight, - self.dt_proj.weight, - self.out_proj.weight, - self.out_proj.bias, - A, - None, # input-dependent B - None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - ) - else: - x, z = xz.chunk(2, dim=1) - # Compute short convolution - if 
conv_state is not None: - conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - return out - - def step(self, hidden_states, conv_state, ssm_state): - dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" - xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - x = x + self.conv1d.bias - x = self.act(x).to(dtype=dtype) - else: - x = causal_conv1d_update( - x, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x_db = self.x_proj(x) # (B dt_rank+2*d_state) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - # Don't add dt_bias here - dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - # SSM step - if selective_state_update is None: - # Discretize A and B - dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) - y = y + self.D.to(dtype) * x - y = y * self.act(z) # (B D) - else: - y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - - out = self.out_proj(y) - return out.unsqueeze(1), conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype - ) - return conv_state, ssm_state - 
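-    # Decoding cache layout (illustrative shapes, not asserted by this code): conv_state is
-    # (batch, d_model * expand, d_conv) and ssm_state is (batch, d_model * expand, d_state);
-    # e.g. d_model=192, expand=2, d_conv=4, d_state=16 gives (B, 384, 4) and (B, 384, 16),
-    # both updated in place by step() during token-by-token decoding.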
- def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=self.dt_proj.weight.device, - dtype=self.dt_proj.weight.dtype, - # dtype=torch.float32, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - - -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). 
- residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/mamba/mamba_ssm/ops/__init__.py b/mamba/mamba_ssm/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mamba/mamba_ssm/ops/selective_scan_interface.py b/mamba/mamba_ssm/ops/selective_scan_interface.py deleted file mode 100644 index 99b455ed949c123bb453922d5ac88d00f401e392..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/ops/selective_scan_interface.py +++ /dev/null @@ -1,709 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Albert Gu. - -import torch -import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd - -from einops import rearrange, repeat - -from causal_conv1d import causal_conv1d_fn -import causal_conv1d_cuda -import selective_scan_cuda - - -class SelectiveScanFn(torch.autograd.Function): - - @staticmethod - def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): - if u.stride(-1) != 1: - u = u.contiguous() - if delta.stride(-1) != 1: - delta = delta.contiguous() - if D is not None: - D = D.contiguous() - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if z is not None and z.stride(-1) != 1: - z = z.contiguous() - if B.dim() == 3: - B = rearrange(B, "b dstate l -> b 1 dstate l") - ctx.squeeze_B = True - if C.dim() == 3: - C = rearrange(C, "b dstate l -> b 1 dstate l") - ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) - ctx.delta_softplus = delta_softplus - ctx.has_z = z is not None - last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) - if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) - return out if not return_last_state else (out, last_state) - else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) - out_z = rest[0] - return out_z if not return_last_state else (out_z, last_state) - - @staticmethod - def backward(ctx, dout, *args): - if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors - z = None - out = None - else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors - if dout.stride(-1) != 1: - dout = dout.contiguous() - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - # Here we just pass in None and dz will be allocated in the C++ code. 
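-        # The gradient tuple returned below lines up positionally with forward's inputs
-        # (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state);
-        # the last two slots are None because those arguments are flags, not tensors.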
- du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here - ) - dz = rest[0] if ctx.has_z else None - dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB - dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - return (du, ddelta, dA, dB, dC, - dD if D is not None else None, - dz, - ddelta_bias if delta_bias is not None else None, - None, - None) - - -def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): - """if return_last_state is True, returns (out, last_state) - last_state has shape (batch, dim, dstate). Note that the gradient of the last state is - not considered in the backward pass. - """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) - - -def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): - """ - u: r(B D L) - delta: r(B D L) - A: c(D N) or r(D N) - B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - - out: r(B D L) - last_state (optional): r(B D dstate) or c(B D dstate) - """ - dtype_in = u.dtype - u = u.float() - delta = delta.float() - if delta_bias is not None: - delta = delta + delta_bias[..., None].float() - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - is_variable_B = B.dim() >= 3 - is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) - else: - B = B.float() - C = C.float() - x = A.new_zeros((batch, dim, dstate)) - ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) - else: - if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) - else: - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) - if is_variable_C and C.dim() == 4: - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) - else: - if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) - else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - if y.is_complex(): - y = y.real * 2 - ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) - out = y if D is None else y + u * rearrange(D, "d -> d 1") - if z is not None: - out = out * F.silu(z) - out = out.to(dtype=dtype_in) - return out if not return_last_state else (out, last_state) - - -class MambaInnerFnNoOutProj(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): - """ - xz: (batch, dim, seqlen) - """ - assert checkpoint_lvl in [0, 1] - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - if torch.is_autocast_enabled(): - x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - if xz.stride(-1) != 1: - xz = xz.contiguous() - conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - x, z = xz.chunk(2, dim=1) - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
- x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - ctx.is_variable_B = B is None - ctx.is_variable_C = C is None - ctx.B_proj_bias_is_None = B_proj_bias is None - ctx.C_proj_bias_is_None = C_proj_bias is None - if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) - if B_proj_bias is not None: - B = B + B_proj_bias.to(dtype=B.dtype) - if not A.is_complex(): - # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if B.stride(-1) != 1: - B = B.contiguous() - if C is None: # variable C - C = x_dbl[:, -d_state:] # (bl dstate) - if C_proj_bias is not None: - C = C + C_proj_bias.to(dtype=C.dtype) - if not A.is_complex(): - # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if C.stride(-1) != 1: - C = C.contiguous() - if D is not None: - D = D.contiguous() - out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus - ) - ctx.delta_softplus = delta_softplus - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass - conv1d_out, delta = None, None - ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, - delta_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out) - # return rearrange(out_z, "b d l -> b l d") - return out_z - - @staticmethod - @custom_bwd - def backward(ctx, dout): - # dout: (batch, seqlen, dim) - (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, dim=1) - if dout.stride(-1) != 1: - dout = dout.contiguous() - if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), - "d (b l) -> b d l", l = L) - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). 
- dxz = torch.empty_like(xz) # (batch, dim, seqlen) - dx, dz = dxz.chunk(2, dim=1) - # dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l - dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz, - ctx.delta_softplus, - True # option to recompute out_z - ) - dD = dD if D is not None else None - dx_dbl = torch.empty_like(x_dbl) - dB_proj_bias = None - if ctx.is_variable_B: - if not A.is_complex(): - dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None - dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) - dB = None - dC_proj_bias = None - if ctx.is_variable_C: - if not A.is_complex(): - dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None - dx_dbl[:, -d_state:] = dC # (bl d) - dC = None - ddelta = rearrange(ddelta, "b d l -> d (b l)") - ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) - dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) - dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") - dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) - dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) - dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). 
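-        # dx below is a chunk view of dxz, so the conv1d backward writes the input
-        # gradient straight into the packed (dx, dz) buffer with no extra concatenation.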
- dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True - ) - dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None - dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, - dA, dB, dC, dD, - ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) - - -class MambaInnerFn(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): - """ - xz: (batch, dim, seqlen) - """ - assert checkpoint_lvl in [0, 1] - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - if torch.is_autocast_enabled(): - x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) - if out_proj_bias is not None else None) - if xz.stride(-1) != 1: - xz = xz.contiguous() - conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - x, z = xz.chunk(2, dim=1) - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
- x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - ctx.is_variable_B = B is None - ctx.is_variable_C = C is None - ctx.B_proj_bias_is_None = B_proj_bias is None - ctx.C_proj_bias_is_None = C_proj_bias is None - if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) - if B_proj_bias is not None: - B = B + B_proj_bias.to(dtype=B.dtype) - if not A.is_complex(): - # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if B.stride(-1) != 1: - B = B.contiguous() - if C is None: # variable C - C = x_dbl[:, -d_state:] # (bl dstate) - if C_proj_bias is not None: - C = C + C_proj_bias.to(dtype=C.dtype) - if not A.is_complex(): - # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if C.stride(-1) != 1: - C = C.contiguous() - if D is not None: - D = D.contiguous() - out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus - ) - ctx.delta_softplus = delta_softplus - ctx.out_proj_bias_is_None = out_proj_bias is None - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass - conv1d_out, delta = None, None - ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, - delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out) - return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - - @staticmethod - @custom_bwd - def backward(ctx, dout): - # dout: (batch, seqlen, dim) - (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, dim=1) - if dout.stride(-1) != 1: - dout = dout.contiguous() - if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), - "d (b l) -> b d l", l = L) - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). 
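-        # dout arrives as (batch, seqlen, dim); it is first folded back through out_proj
-        # below to recover dout_y in (batch, dim, seqlen) layout for the scan backward.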
- dxz = torch.empty_like(xz) # (batch, dim, seqlen) - dx, dz = dxz.chunk(2, dim=1) - dout = rearrange(dout, "b l e -> e (b l)") - dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) - dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, - ctx.delta_softplus, - True # option to recompute out_z - ) - dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) - dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None - dD = dD if D is not None else None - dx_dbl = torch.empty_like(x_dbl) - dB_proj_bias = None - if ctx.is_variable_B: - if not A.is_complex(): - dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None - dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) - dB = None - dC_proj_bias = None - if ctx.is_variable_C: - if not A.is_complex(): - dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None - dx_dbl[:, -d_state:] = dC # (bl d) - dC = None - ddelta = rearrange(ddelta, "b d l -> d (b l)") - ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) - dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) - dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") - dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) - dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) - dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). 
- dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True - ) - dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None - dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, - dout_proj_weight, dout_proj_bias, - dA, dB, dC, dD, - ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) - - -class BiMambaInnerFn(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): - """ - xz: (batch, dim, seqlen) - """ - assert checkpoint_lvl in [0, 1] - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - if torch.is_autocast_enabled(): - x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) - if out_proj_bias is not None else None) - if xz.stride(-1) != 1: - xz = xz.contiguous() - conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - x, z = xz.chunk(2, dim=1) - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) - ctx.is_variable_B = B is None - ctx.is_variable_C = C is None - ctx.B_proj_bias_is_None = B_proj_bias is None - ctx.C_proj_bias_is_None = C_proj_bias is None - if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) - if B_proj_bias is not None: - B = B + B_proj_bias.to(dtype=B.dtype) - if not A.is_complex(): - # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if B.stride(-1) != 1: - B = B.contiguous() - if C is None: # variable C - C = x_dbl[:, -d_state:] # (bl dstate) - if C_proj_bias is not None: - C = C + C_proj_bias.to(dtype=C.dtype) - if not A.is_complex(): - # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() - else: - if C.stride(-1) != 1: - C = C.contiguous() - if D is not None: - D = D.contiguous() - out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus - ) - assert not A_b.is_complex(), "A should not be complex!!" 
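-        # Reverse direction: rerun the scan on inputs flipped along the sequence axis,
-        # then flip its output back so the two directions are summed position-by-position.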
- out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd( - conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus, - ) - - out_z = out_z_f + out_z_b.flip([-1]) - - ctx.delta_softplus = delta_softplus - ctx.out_proj_bias_is_None = out_proj_bias is None - ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass - conv1d_out, delta = None, None - ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, - delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) - return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) - - @staticmethod - @custom_bwd - def backward(ctx, dout): - # dout: (batch, seqlen, dim) - (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, dim=1) - if dout.stride(-1) != 1: - dout = dout.contiguous() - if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), - "d (b l) -> b d l", l = L) - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the - # backward of selective_scan_cuda with the backward of chunk). - dxz = torch.empty_like(xz) # (batch, dim, seqlen) - dx, dz = dxz.chunk(2, dim=1) - dout = rearrange(dout, "b l e -> e (b l)") - dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) - dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz, - ctx.delta_softplus, - True # option to recompute out_z - ) - # flip one - dz_b = torch.empty_like(dz) - dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd( - conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b, - ctx.delta_softplus, - True # option to recompute out_z - ) - - dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1]) - ddelta = ddelta + ddelta_f_b.flip([-1]) - dB = dB + dB_f_b.flip([-1]) - dC = dC + dC_f_b.flip([-1]) - dD = dD + dD_b - ddelta_bias = ddelta_bias + ddelta_bias_b - dz = dz + dz_b.flip([-1]) - out_z = out_z_f + out_z_b.flip([-1]) - - dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) - dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None - dD = dD if D is not None else None - dx_dbl = torch.empty_like(x_dbl) - dB_proj_bias = None - if ctx.is_variable_B: - if not A.is_complex(): - dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None - dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) - dB = None - dC_proj_bias = None - if ctx.is_variable_C: - if not A.is_complex(): - dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() - else: - dC = 
rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() - dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None - dx_dbl[:, -d_state:] = dC # (bl d) - dC = None - ddelta = rearrange(ddelta, "b d l -> d (b l)") - ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) - dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) - dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") - dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) - dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) - dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) - # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the - # backward of conv1d with the backward of chunk). - dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, dx, True - ) - dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None - dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, - dout_proj_weight, dout_proj_bias, - dA, dA_b, dB, dC, dD, - ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) - - -def mamba_inner_fn( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True -): - return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) - -def bimamba_inner_fn( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True -): - return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) - - -def mamba_inner_fn_no_out_proj( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True -): - return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) - - -def mamba_inner_ref( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True -): - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
- x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) - delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() - delta = rearrange(delta, "d (b l) -> b d l", l=L) - if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) - if B_proj_bias is not None: - B = B + B_proj_bias.to(dtype=B.dtype) - if not A.is_complex(): - B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - if C is None: # variable B - C = x_dbl[:, -d_state:] # (bl d) - if C_proj_bias is not None: - C = C + C_proj_bias.to(dtype=C.dtype) - if not A.is_complex(): - C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) - return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) - - -def bimamba_inner_ref( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True -): - L = xz.shape[-1] - delta_rank = delta_proj_weight.shape[1] - d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") - # We're being very careful here about the layout, to avoid extra transposes. - # We want delta to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) - delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() - delta = rearrange(delta, "d (b l) -> b d l", l=L) - if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) - if B_proj_bias is not None: - B = B + B_proj_bias.to(dtype=B.dtype) - if not A.is_complex(): - B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() - else: - B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - if C is None: # variable B - C = x_dbl[:, -d_state:] # (bl d) - if C_proj_bias is not None: - C = C + C_proj_bias.to(dtype=C.dtype) - if not A.is_complex(): - C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() - else: - C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) - y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True) - y = y + y_b.flip([-1]) - return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/mamba/mamba_ssm/ops/triton/__init__.py b/mamba/mamba_ssm/ops/triton/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mamba/mamba_ssm/ops/triton/layernorm.py b/mamba/mamba_ssm/ops/triton/layernorm.py deleted file mode 100644 index 70d57397f6e5af1138e8df62629f9ab57174f6f4..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/ops/triton/layernorm.py +++ /dev/null @@ -1,636 +0,0 @@ -# Copyright (c) 2023, Tri Dao. -# Implement residual + layer_norm / rms_norm. 
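-# Usage sketch (assumed shapes, illustration only):
-#   w = torch.ones(1024, device="cuda"); x = torch.randn(8, 1024, device="cuda")
-#   y, new_residual = layer_norm_fn(x, w, None, residual=x.clone(),
-#                                   prenorm=True, residual_in_fp32=True)  # fused add + LayerNorm
-#   rms_norm_fn exposes the same interface with is_rms_norm hard-wired to True.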
- -# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. -# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. -# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. - -import math - -import torch -import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd - -import triton -import triton.language as tl - - -def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): - dtype = x.dtype - if upcast: - weight = weight.float() - bias = bias.float() if bias is not None else None - if upcast: - x = x.float() - residual = residual.float() if residual is not None else residual - if residual is not None: - x = (x + residual).to(x.dtype) - out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( - dtype - ) - return out if not prenorm else (out, x) - - -def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): - dtype = x.dtype - if upcast: - weight = weight.float() - bias = bias.float() if bias is not None else None - if upcast: - x = x.float() - residual = residual.float() if residual is not None else residual - if residual is not None: - x = (x + residual).to(x.dtype) - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) - out = out.to(dtype) - return out if not prenorm else (out, x) - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - RESIDUAL, # pointer to the residual - RESIDUAL_OUT, # pointer to the residual - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_res_row, - stride_res_out_row, - N, # number of columns in X - eps, # epsilon to avoid division by zero - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_RESIDUAL: tl.constexpr, - STORE_RESIDUAL_OUT: tl.constexpr, - HAS_BIAS: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. 
- row = tl.program_id(0) - X += row * stride_x_row - Y += row * stride_y_row - if HAS_RESIDUAL: - RESIDUAL += row * stride_res_row - if STORE_RESIDUAL_OUT: - RESIDUAL_OUT += row * stride_res_out_row - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_RESIDUAL: - residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) - x += residual - if STORE_RESIDUAL_OUT: - tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - - -def _layer_norm_fwd( - x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False -): - if residual is not None: - residual_dtype = residual.dtype - M, N = x.shape - assert x.stride(-1) == 1 - if residual is not None: - assert residual.stride(-1) == 1 - assert residual.shape == (M, N) - assert weight.shape == (N,) - assert weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - # allocate output - y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - assert y.stride(-1) == 1 - if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): - residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) - assert residual_out.stride(-1) == 1 - else: - residual_out = None - mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device="cuda") - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[(M,)]( - x, - y, - weight, - bias, - residual, - residual_out, - mean, - rstd, - x.stride(0), - y.stride(0), - residual.stride(0) if residual is not None else 0, - residual_out.stride(0) if residual_out is not None else 0, - N, - eps, - is_rms_norm, - BLOCK_N, - residual is not None, - residual_out is not None, - bias is not None, - ) - # residual_out is None if residual is None and residual_dtype == input_dtype - return y, mean, rstd, residual_out if residual_out is not None else x - - -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], -) -# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) -# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not 
None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) -@triton.jit -def _layer_norm_bwd_kernel( - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Y, # pointer to the output to be recomputed - DY, # pointer to the output gradient - DX, # pointer to the input gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - DRESIDUAL, - DRESIDUAL_IN, - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_dy_row, - stride_dx_row, - stride_dres_row, - stride_dres_in_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - rows_per_program, - IS_RMS_NORM: tl.constexpr, - BLOCK_N: tl.constexpr, - HAS_DRESIDUAL: tl.constexpr, - STORE_DRESIDUAL: tl.constexpr, - HAS_BIAS: tl.constexpr, - RECOMPUTE_OUTPUT: tl.constexpr, -): - # Map the program id to the elements of X, DX, and DY it should compute. - row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - cols = tl.arange(0, BLOCK_N) - mask = cols < N - X += row_start * stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += row_start * stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += row_start * stride_dres_in_row - DY += row_start * stride_dy_row - DX += row_start * stride_dx_row - if RECOMPUTE_OUTPUT: - Y += row_start * stride_y_row - w = tl.load(W + cols, mask=mask).to(tl.float32) - if RECOMPUTE_OUTPUT and HAS_BIAS: - b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) - dw = tl.zeros((BLOCK_N,), dtype=tl.float32) - if HAS_BIAS: - db = tl.zeros((BLOCK_N,), dtype=tl.float32) - row_end = min((row_block_id + 1) * rows_per_program, M) - for row in range(row_start, row_end): - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - if not IS_RMS_NORM: - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - xhat = tl.where(mask, xhat, 0.0) - if RECOMPUTE_OUTPUT: - y = xhat * w + b if HAS_BIAS else xhat * w - tl.store(Y + cols, y, mask=mask) - wdy = w * dy - dw += dy * xhat - if HAS_BIAS: - db += dy - if not IS_RMS_NORM: - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - else: - c1 = tl.sum(xhat * wdy, axis=0) / N - dx = (wdy - xhat * c1) * rstd - if HAS_DRESIDUAL: - dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) - dx += dres - # Write dx - if STORE_DRESIDUAL: - tl.store(DRESIDUAL_IN + cols, dx, mask=mask) - tl.store(DX + cols, dx, mask=mask) - - X += stride_x_row - if HAS_DRESIDUAL: - DRESIDUAL += stride_dres_row - if STORE_DRESIDUAL: - DRESIDUAL_IN += stride_dres_in_row - if RECOMPUTE_OUTPUT: - Y += stride_y_row - DY += stride_dy_row - DX += stride_dx_row - tl.store(DW + row_block_id * N + cols, dw, mask=mask) - if HAS_BIAS: - tl.store(DB + row_block_id * N + cols, db, mask=mask) - - -def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual=None, - has_residual=False, - is_rms_norm=False, - x_dtype=None, - recompute_output=False, -): - M, N = x.shape - assert x.stride(-1) == 1 - assert dy.stride(-1) == 1 - assert dy.shape == (M, N) - if dresidual is not None: - assert dresidual.stride(-1) == 1 - assert dresidual.shape == (M, N) - assert weight.shape == (N,) - assert 
weight.stride(-1) == 1 - if bias is not None: - assert bias.stride(-1) == 1 - assert bias.shape == (N,) - # allocate output - dx = ( - torch.empty_like(x) - if x_dtype is None - else torch.empty(M, N, dtype=x_dtype, device=x.device) - ) - dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None - y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None - - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_N: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count - _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) - _db = ( - torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) - if bias is not None - else None - ) - rows_per_program = math.ceil(M / sm_count) - grid = (sm_count,) - with torch.cuda.device(x.device.index): - _layer_norm_bwd_kernel[grid]( - x, - weight, - bias, - y, - dy, - dx, - _dw, - _db, - dresidual, - dresidual_in, - mean, - rstd, - x.stride(0), - 0 if not recompute_output else y.stride(0), - dy.stride(0), - dx.stride(0), - dresidual.stride(0) if dresidual is not None else 0, - dresidual_in.stride(0) if dresidual_in is not None else 0, - M, - N, - eps, - rows_per_program, - is_rms_norm, - BLOCK_N, - dresidual is not None, - dresidual_in is not None, - bias is not None, - ) - dw = _dw.sum(0).to(weight.dtype) - db = _db.sum(0).to(bias.dtype) if bias is not None else None - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype: - dresidual_in = dx - return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) - - -class LayerNormFn(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - weight, - bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, mean, rstd, residual_out = _layer_norm_fwd( - x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm - ) - ctx.save_for_backward(residual_out, weight, bias, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - y = y.reshape(x_shape_og) - return y if not prenorm else (y, residual_out.reshape(x_shape_og)) - - @staticmethod - def backward(ctx, dy, *args): - x, weight, bias, mean, rstd = ctx.saved_tensors - dy = dy.reshape(-1, dy.shape[-1]) - if dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dw, db, dresidual_in = _layer_norm_bwd( - dy, - x, - 
weight, - bias, - ctx.eps, - mean, - rstd, - dresidual, - ctx.has_residual, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - ) - return ( - dx.reshape(ctx.x_shape_og), - dw, - db, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_fn( - x, - weight, - bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) - - -def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): - return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) - - def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): - return rms_norm_fn( - x, - self.weight, - self.bias, - residual=residual, - eps=self.eps, - prenorm=prenorm, - residual_in_fp32=residual_in_fp32, - # is_rms_norm=True, - ) - - -class LayerNormLinearFn(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, - ): - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if residual is not None: - assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() - norm_weight = norm_weight.contiguous() - if norm_bias is not None: - norm_bias = norm_bias.contiguous() - residual_dtype = ( - residual.dtype - if residual is not None - else (torch.float32 if residual_in_fp32 else None) - ) - y, mean, rstd, residual_out = _layer_norm_fwd( - x, - norm_weight, - norm_bias, - eps, - residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), - residual_dtype=residual_dtype, - is_rms_norm=is_rms_norm, - ) - y = y.reshape(x_shape_og) - dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype - linear_weight = linear_weight.to(dtype) - linear_bias = linear_bias.to(dtype) if linear_bias is not None else None - out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) - # We don't store y, will be recomputed in the backward pass to save memory - ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) - ctx.x_shape_og = x_shape_og - ctx.eps = eps - ctx.is_rms_norm = is_rms_norm - ctx.has_residual = residual is not None - ctx.prenorm = prenorm - ctx.x_dtype = x.dtype - ctx.linear_bias_is_none = linear_bias is None - return out if not prenorm else (out, residual_out.reshape(x_shape_og)) - - @staticmethod - @custom_bwd - def backward(ctx, dout, *args): - x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors - dout = dout.reshape(-1, dout.shape[-1]) - dy = F.linear(dout, linear_weight.t()) - dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - if dy.stride(-1) != 1: - dy = dy.contiguous() 
- assert dy.shape == x.shape - if ctx.prenorm: - dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() - assert dresidual.shape == x.shape - else: - dresidual = None - dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( - dy, - x, - norm_weight, - norm_bias, - ctx.eps, - mean, - rstd, - dresidual, - ctx.has_residual, - ctx.is_rms_norm, - x_dtype=ctx.x_dtype, - recompute_output=True, - ) - dlinear_weight = torch.einsum("bo,bi->oi", dout, y) - return ( - dx.reshape(ctx.x_shape_og), - dnorm_weight, - dnorm_bias, - dlinear_weight, - dlinear_bias, - dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, - None, - None, - None, - None, - ) - - -def layer_norm_linear_fn( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual=None, - eps=1e-6, - prenorm=False, - residual_in_fp32=False, - is_rms_norm=False, -): - return LayerNormLinearFn.apply( - x, - norm_weight, - norm_bias, - linear_weight, - linear_bias, - residual, - eps, - prenorm, - residual_in_fp32, - is_rms_norm, - ) diff --git a/mamba/mamba_ssm/ops/triton/selective_state_update.py b/mamba/mamba_ssm/ops/triton/selective_state_update.py deleted file mode 100644 index fa95de73f173292914c5f00fbe9426937d00e502..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/ops/triton/selective_state_update.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright (c) 2023, Tri Dao. - -"""We want triton==2.1.0 for this -""" - -import math -import torch -import torch.nn.functional as F - -import triton -import triton.language as tl - -from einops import rearrange, repeat - - -@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) -@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) -@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) -@triton.jit -def _selective_scan_update_kernel( - # Pointers to matrices - state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, - # Matrix dimensions - batch, dim, dstate, - # Strides - stride_state_batch, stride_state_dim, stride_state_dstate, - stride_x_batch, stride_x_dim, - stride_dt_batch, stride_dt_dim, - stride_dt_bias_dim, - stride_A_dim, stride_A_dstate, - stride_B_batch, stride_B_dstate, - stride_C_batch, stride_C_dstate, - stride_D_dim, - stride_z_batch, stride_z_dim, - stride_out_batch, stride_out_dim, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - HAS_D: tl.constexpr, - HAS_Z: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_b = tl.program_id(axis=1) - state_ptr += pid_b * stride_state_batch - x_ptr += pid_b * stride_x_batch - dt_ptr += pid_b * stride_dt_batch - B_ptr += pid_b * stride_B_batch - C_ptr += pid_b * stride_C_batch - if HAS_Z: - z_ptr += pid_b * stride_z_batch - out_ptr += pid_b * stride_out_batch - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) - state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim - if HAS_DT_BIAS: - dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim - A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) - B_ptrs = B_ptr + offs_n * 
stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate - if HAS_D: - D_ptrs = D_ptr + offs_m * stride_D_dim - if HAS_Z: - z_ptrs = z_ptr + offs_m * stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - - state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.log(1.0 + tl.exp(dt)) - A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) - dA = tl.exp(A * dt[:, None]) - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - - dB = B[None, :] * dt[:, None] - state = state * dA + dB * x[:, None] - tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) - out = tl.sum(state * C[None, :], axis=1) - if HAS_D: - out += x * D - if HAS_Z: - out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) - - -def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): - """ - Argument: - state: (batch, dim, dstate) - x: (batch, dim) - dt: (batch, dim) - A: (dim, dstate) - B: (batch, dstate) - C: (batch, dstate) - D: (dim,) - z: (batch, dim) - dt_bias: (dim,) - Return: - out: (batch, dim) - """ - batch, dim, dstate = state.shape - assert x.shape == (batch, dim) - assert dt.shape == x.shape - assert A.shape == (dim, dstate) - assert B.shape == (batch, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (dim,) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (dim,) - out = torch.empty_like(x) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) - z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) - # We don't want autotune since it will overwrite the state - # We instead tune by hand. 
- BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 - else ((16, 4) if dstate <= 32 else - ((8, 4) if dstate <= 64 else - ((4, 4) if dstate <= 128 else - ((4, 8)))))) - with torch.cuda.device(x.device.index): - _selective_scan_update_kernel[grid]( - state, x, dt, dt_bias, A, B, C, D, z, out, - batch, dim, dstate, - state.stride(0), state.stride(1), state.stride(2), - x.stride(0), x.stride(1), - dt.stride(0), dt.stride(1), - dt_bias.stride(0) if dt_bias is not None else 0, - A.stride(0), A.stride(1), - B.stride(0), B.stride(1), - C.stride(0), C.stride(1), - D.stride(0) if D is not None else 0, - z_strides[0], z_strides[1], - out.stride(0), out.stride(1), - dt_softplus, - BLOCK_SIZE_M, - num_warps=num_warps, - ) - return out - - -def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): - """ - Argument: - state: (batch, dim, dstate) - x: (batch, dim) - dt: (batch, dim) - A: (dim, dstate) - B: (batch, dstate) - C: (batch, dstate) - D: (dim,) - z: (batch, dim) - dt_bias: (dim,) - Return: - out: (batch, dim) - """ - batch, dim, dstate = state.shape - assert x.shape == (batch, dim) - assert dt.shape == x.shape - assert A.shape == (dim, dstate) - assert B.shape == (batch, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (dim,) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (dim,) - dt = dt + dt_bias - dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) - dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) - state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate - out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) - if D is not None: - out += (x * D).to(out.dtype) - return (out if z is None else out * F.silu(z)).to(x.dtype) diff --git a/mamba/mamba_ssm/utils/__init__.py b/mamba/mamba_ssm/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/mamba/mamba_ssm/utils/generation.py b/mamba/mamba_ssm/utils/generation.py deleted file mode 100644 index 9d766b29ac28a388a7d77b22aa2cb1eda733c0f4..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/utils/generation.py +++ /dev/null @@ -1,377 +0,0 @@ -# Copyright (c) 2023, Albert Gu, Tri Dao. 
-import gc -import time -from collections import namedtuple -from dataclasses import dataclass, field -from functools import partial -from typing import Callable, Optional, Sequence, Union - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import Tensor -from torch.profiler import ProfilerActivity, profile, record_function -from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput - - -@dataclass -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference.""" - - max_seqlen: int - max_batch_size: int - seqlen_offset: int = 0 - batch_size_offset: int = 0 - key_value_memory_dict: dict = field(default_factory=dict) - lengths_per_sample: Optional[Tensor] = None - - def reset(self, max_seqlen, max_batch_size): - self.max_seqlen = max_seqlen - self.max_batch_size = max_batch_size - self.seqlen_offset = 0 - if self.lengths_per_sample is not None: - self.lengths_per_sample.zero_() - - -# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py -# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 -def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf. Done in-place.""" - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits.masked_fill_(indices_to_remove, float("-Inf")) - - -# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py -# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 -def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf. Done in-place.""" - if top_p <= 0.0 or top_p >= 1.0: - return - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=False) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs <= (1 - top_p) - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) - logits.masked_fill_(indices_to_remove, float("-inf")) - - -def sample(logits, top_k=1, top_p=0.0, temperature=1.0): - """Sample from top-k logits. - Arguments: - logits: Tensor of shape (batch_size, vocab_size) - """ - if top_k == 1: # Short-circuit for greedy decoding - return logits.argmax(dim=-1) - else: - if top_p > 0.0: - assert top_p <= 1.0, "top-p should be in (0, 1]." 
- if top_k > 0: - top_k = min(top_k, logits.size(-1)) # Safety check - logits_top, indices = torch.topk(logits, top_k, dim=-1) - if temperature != 1.0: - logits_top /= temperature - modify_logits_for_top_p_filtering(logits_top, top_p) - return indices[ - torch.arange(indices.shape[0], device=indices.device), - torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), - ] - else: - # Clone so that when we modify for top_p we don't change the original logits - logits_top = logits / temperature if temperature != 1.0 else logits.clone() - modify_logits_for_top_p_filtering(logits_top, top_p) - return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( - dim=-1 - ) - - -@torch.inference_mode() -def decode( - input_ids, - model, - max_length, - top_k=1, - top_p=0.0, - temperature=1.0, - eos_token_id=None, - teacher_outputs=None, - vocab_size=None, - tensor_parallel=1, - cg=False, - enable_timing=False, -): - """Decoding, either greedy or with top-k or top-p sampling. - If top-k = 0, don't limit the number of candidates (pure sampling). - Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, - then top-p. - We assume that all sequences in the same batch have the same length. - - Arguments: - input_ids: (batch, seq_len) - max_length: int - teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the - logits, the next token is taken from the teacher_outputs. Useful for testing. - Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: - sequences: (batch, max_length) - scores: tuples of (batch, vocab_size) - """ - batch_size, seqlen_og = input_ids.shape - teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 - if cg: - if not hasattr(model, "_decoding_cache"): - model._decoding_cache = None - model._decoding_cache = update_graph_cache( - model, - model._decoding_cache, - batch_size, - seqlen_og, - max_length, - tensor_parallel=tensor_parallel, - ) - inference_params = model._decoding_cache.inference_params - inference_params.reset(max_length, batch_size) - else: - inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) - - def get_logits(input_ids, inference_params): - decoding = inference_params.seqlen_offset > 0 - if decoding: - position_ids = torch.full( - (batch_size, 1), - inference_params.seqlen_offset, - dtype=torch.long, - device=input_ids.device, - ) - else: - position_ids = None - if not cg or not decoding: - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=1, - ).logits.squeeze(dim=1) - else: - logits = model._decoding_cache.run( - input_ids, position_ids, inference_params.seqlen_offset - ).squeeze(dim=1) - return logits[..., :vocab_size] if vocab_size is not None else logits - - def sample_tokens(logits, inference_params): - if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: - token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) - else: - token = teacher_outputs[:, inference_params.seqlen_offset] - # return rearrange(token, "b -> b 1") - return token.unsqueeze(1) - - def should_stop(current_token, inference_params): - if inference_params.seqlen_offset == 0: - return False - if eos_token_id is not None and (current_token == eos_token_id).all(): - return True - if inference_params.seqlen_offset >= max_length - 1: - return True - return False - - start = 
torch.cuda.Event(enable_timing=enable_timing) - end = torch.cuda.Event(enable_timing=enable_timing) - - if enable_timing: - if tensor_parallel > 1: - torch.distributed.barrier() - start.record() - scores, sequences = [], [input_ids] - while not should_stop(sequences[-1], inference_params): - scores.append(get_logits(sequences[-1], inference_params)) - inference_params.seqlen_offset += sequences[-1].shape[1] - sequences.append(sample_tokens(scores[-1], inference_params)) - if enable_timing: - end.record() - if tensor_parallel > 1: - torch.distributed.barrier() - torch.cuda.synchronize() - print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") - output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput - return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) - - -class GenerationMixin: - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - raise NotImplementedError - - def generate( - self, - input_ids, - max_length, - top_k=1, - top_p=0.0, - temperature=1.0, - return_dict_in_generate=False, - output_scores=False, - **kwargs, - ): - output = decode( - input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs - ) - if not output_scores: - output.scores = None - return output if return_dict_in_generate else output.sequences - - -def allocate_inference_cache( - max_batch_size, - max_seqlen, - nheads, - headdim, - layers: Union[int, Sequence], - device, - dtype=torch.float16, -): - assert dtype in [torch.float16, torch.bfloat16, torch.float32] - kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) - if isinstance(layers, int): - layers = range(layers) - return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} - - -@dataclass -class DecodingCGCache: - max_batch_size: int = 0 - max_seqlen: int = 0 - device = None - dtype = None - callables: dict = field(default_factory=dict) - mempool = None - inference_params: Optional[InferenceParams] = None - run: Optional[Callable] = None - - -@torch.inference_mode() -def update_graph_cache( - model, - cache, - batch_size, - seqlen_og, - max_seqlen, - decoding_seqlens=(1,), - tensor_parallel=1, - dtype=None, - n_warmups=2, -): - if cache is None: - cache = DecodingCGCache() - param_example = next(iter(model.parameters())) - device = param_example.device - if dtype is None: - dtype = param_example.dtype - if ( - (device, dtype) != (cache.device, cache.dtype) - or batch_size > cache.max_batch_size - or max_seqlen > cache.max_seqlen - ): # Invalidate the cache - cache.callables = {} - cache.mempool = None - cache.inference_params = None - gc.collect() - cache.device, cache.dtype = device, dtype - cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen - if hasattr(model, "allocate_inference_cache"): - inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) - else: - headdim = getattr( - model.config, - "head_dim", - model.config.hidden_size // model.config.num_attention_heads, - ) - inf_cache = allocate_inference_cache( - batch_size, - max_seqlen, - model.config.num_attention_heads // tensor_parallel, - headdim, - model.config.num_hidden_layers, - device, - dtype, - ) - lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) - cache.inference_params = InferenceParams( - max_seqlen=max_seqlen, - max_batch_size=batch_size, - seqlen_offset=seqlen_og, - key_value_memory_dict=inf_cache, - lengths_per_sample=lengths_per_sample, - ) - 
cache.mempool = torch.cuda.graphs.graph_pool_handle() - for decoding_seqlen in decoding_seqlens: - if (batch_size, decoding_seqlen) not in cache.callables: - cache.callables[batch_size, decoding_seqlen] = capture_graph( - model, - cache.inference_params, - batch_size, - max_seqlen, - decoding_seqlen=decoding_seqlen, - mempool=cache.mempool, - n_warmups=n_warmups, - ) - - def dispatch(input_ids, position_ids, seqlen): - batch_size, decoding_seqlen = input_ids.shape[:2] - return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) - - cache.run = dispatch - cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing - return cache - - -def capture_graph( - model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 -): - device = next(iter(model.parameters())).device - input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) - position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) - seqlen_offset_og = inference_params.seqlen_offset - inference_params.seqlen_offset = max_seqlen - decoding_seqlen - inference_params.lengths_per_sample[:] = inference_params.seqlen_offset - - # Warmup before capture - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(n_warmups): - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=decoding_seqlen, - ).logits - s.synchronize() - # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, - # which requires that graph launch and non-captured launch to not overlap (I think, - # that's how I interpret the documentation). I'm not sure if this is required. 
- if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.cuda.current_stream().wait_stream(s) - # Captures the graph - # To allow capture, automatically sets a side stream as the current stream in the context - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, pool=mempool): - logits = model( - input_ids, - position_ids=position_ids, - inference_params=inference_params, - num_last_tokens=decoding_seqlen, - ).logits - - def run(new_input_ids, new_position_ids, seqlen): - inference_params.lengths_per_sample[:] = seqlen - input_ids.copy_(new_input_ids) - position_ids.copy_(new_position_ids) - graph.replay() - return logits.clone() - - inference_params.seqlen_offset = seqlen_offset_og - return run diff --git a/mamba/mamba_ssm/utils/hf.py b/mamba/mamba_ssm/utils/hf.py deleted file mode 100644 index 0d7555acddbd260636d1d14d5bd6324f6af0056a..0000000000000000000000000000000000000000 --- a/mamba/mamba_ssm/utils/hf.py +++ /dev/null @@ -1,23 +0,0 @@ -import json - -import torch - -from transformers.utils import WEIGHTS_NAME, CONFIG_NAME -from transformers.utils.hub import cached_file - - -def load_config_hf(model_name): - resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) - return json.load(open(resolved_archive_file)) - - -def load_state_dict_hf(model_name, device=None, dtype=None): - # If not fp32, then we don't want to load directly to the GPU - mapped_device = "cpu" if dtype not in [torch.float32, None] else device - resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) - return torch.load(resolved_archive_file, map_location=mapped_device) - # Convert dtype before moving to GPU to save memory - if dtype is not None: - state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} - state_dict = {k: v.to(device=device) for k, v in state_dict.items()} - return state_dict diff --git a/mamba/setup.py b/mamba/setup.py deleted file mode 100644 index 2ce0ac045f8b2ae07f39f3d045e997ab362ec4c1..0000000000000000000000000000000000000000 --- a/mamba/setup.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) 2023, Albert Gu, Tri Dao. 
-import sys -import warnings -import os -import re -import ast -from pathlib import Path -from packaging.version import parse, Version -import platform -import shutil - -from setuptools import setup, find_packages -import subprocess - -import urllib.request -import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import torch -from torch.utils.cpp_extension import ( - BuildExtension, - CppExtension, - CUDAExtension, - CUDA_HOME, -) - - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - -PACKAGE_NAME = "mamba_ssm" - -BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" - -# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels -# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation -FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" - - -def get_platform(): - """ - Returns the platform name as used in wheel filenames. - """ - if sys.platform.startswith("linux"): - return "linux_x86_64" - elif sys.platform == "darwin": - mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) - return f"macosx_{mac_version}_x86_64" - elif sys.platform == "win32": - return "win_amd64" - else: - raise ValueError("Unsupported platform: {}".format(sys.platform)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True - ) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary - # in that case. - warnings.warn( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - return nvcc_extra_args + ["--threads", "4"] - - -cmdclass = {} -ext_modules = [] - -if not SKIP_CUDA_BUILD: - print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - - check_if_cuda_home_none(PACKAGE_NAME) - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - if CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.6"): - raise RuntimeError( - f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " - "Note: make sure nvcc has a supported version by running nvcc -V." 
- ) - - cc_flag.append("-gencode") - cc_flag.append("arch=compute_70,code=sm_70") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - - # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as - # torch._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 - if FORCE_CXX11_ABI: - torch._C._GLIBCXX_USE_CXX11_ABI = True - - ext_modules.append( - CUDAExtension( - name="selective_scan_cuda", - sources=[ - "csrc/selective_scan/selective_scan.cpp", - "csrc/selective_scan/selective_scan_fwd_fp32.cu", - "csrc/selective_scan/selective_scan_fwd_fp16.cu", - "csrc/selective_scan/selective_scan_fwd_bf16.cu", - "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", - "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", - "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", - "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", - "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", - "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - "-lineinfo", - ] - + cc_flag - ), - }, - include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], - ) - ) - - -def get_package_version(): - with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) - public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("MAMBA_LOCAL_VERSION") - if local_version: - return f"{public_version}+{local_version}" - else: - return str(public_version) - - -def get_wheel_url(): - # Determine the version numbers that will be used to determine the correct wheel - # We're using the CUDA version used to build torch, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - torch_cuda_version = parse(torch.version.cuda) - torch_version_raw = parse(torch.__version__) - # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 - # to save CI time. Minor versions should be compatible. 
- torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" - platform_name = get_platform() - mamba_ssm_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" - torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" - cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() - - # Determine wheel URL based on CUDA version, torch version, python version and OS - wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename - ) - return wheel_url, wheel_filename - - -class CachedWheelsCommand(_bdist_wheel): - """ - The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot - find an existing wheel (which is currently the case for all installs). We use - the environment parameters to detect whether there is already a pre-built version of a compatible - wheel available and short-circuits the standard full build pipeline. - """ - - def run(self): - if FORCE_BUILD: - return super().run() - - wheel_url, wheel_filename = get_wheel_url() - print("Guessing wheel URL: ", wheel_url) - try: - urllib.request.urlretrieve(wheel_url, wheel_filename) - - # Make the archive - # Lifted from the root wheel processing command - # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - impl_tag, abi_tag, plat_tag = self.get_tag() - archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - - wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") - print("Raw wheel path", wheel_path) - shutil.move(wheel_filename, wheel_path) - except urllib.error.HTTPError: - print("Precompiled wheel not found. 
Building from source...") - # If the wheel could not be downloaded, build from source - super().run() - - -setup( - name=PACKAGE_NAME, - version=get_package_version(), - packages=find_packages( - exclude=( - "build", - "csrc", - "include", - "tests", - "dist", - "docs", - "benchmarks", - "mamba_ssm.egg-info", - ) - ), - author="Tri Dao, Albert Gu", - author_email="tri@tridao.me, agu@cs.cmu.edu", - description="Mamba state-space model", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/state-spaces/mamba", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - ext_modules=ext_modules, - cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} - if ext_modules - else { - "bdist_wheel": CachedWheelsCommand, - }, - python_requires=">=3.7", - install_requires=[ - "torch", - "packaging", - "ninja", - "einops", - "triton", - "transformers", - "causal_conv1d", - ], -) diff --git a/mamba/test_mamba_module.py b/mamba/test_mamba_module.py deleted file mode 100644 index 64710e92f7ec4fc0fe88821550e4ecf902a22bfe..0000000000000000000000000000000000000000 --- a/mamba/test_mamba_module.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -from mamba_ssm import Mamba - -batch, length, dim = 2, 64, 768 -x = torch.randn(batch, length, dim).to("cuda") -model = Mamba( - # This module uses roughly 3 * expand * d_model^2 parameters - d_model=dim, # Model dimension d_model - d_state=16, # SSM state expansion factor # 64 - d_conv=4, # Local convolution width - expand=2, # Block expansion factor - use_fast_path=False, -).to("cuda") -y = model(x) -assert y.shape == x.shape diff --git a/mamba/tests/ops/test_selective_scan.py b/mamba/tests/ops/test_selective_scan.py deleted file mode 100644 index 26b34a37560f08ced653a1d9320a14f3d3f9ebd3..0000000000000000000000000000000000000000 --- a/mamba/tests/ops/test_selective_scan.py +++ /dev/null @@ -1,423 +0,0 @@ -# Copyright (C) 2023, Tri Dao. 
- -import math - -import torch -import torch.nn.functional as F -from torch.autograd import gradcheck -import pytest - -from einops import rearrange - -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref -from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref -from mamba_ssm.ops.selective_scan_interface import bimamba_inner_fn, bimamba_inner_ref - - -# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) -@pytest.mark.parametrize('wtype', [torch.float32]) -# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('itype', [torch.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) -# @pytest.mark.parametrize('seqlen', [128]) -# @pytest.mark.parametrize("return_last_state", [False, True]) -@pytest.mark.parametrize("return_last_state", [True]) -# @pytest.mark.parametrize('has_delta_bias', [False, True]) -@pytest.mark.parametrize('has_delta_bias', [True]) -# @pytest.mark.parametrize('delta_softplus', [False, True]) -@pytest.mark.parametrize('delta_softplus', [True]) -# @pytest.mark.parametrize('has_z', [False, True]) -@pytest.mark.parametrize('has_z', [True]) -# @pytest.mark.parametrize('has_D', [False, True]) -@pytest.mark.parametrize('has_D', [True]) -@pytest.mark.parametrize("varBC_groups", [1, 2]) -# @pytest.mark.parametrize("varBC_groups", [1]) -# @pytest.mark.parametrize("is_variable_C", [False, True]) -@pytest.mark.parametrize("is_variable_C", [True]) -# @pytest.mark.parametrize("is_variable_B", [False, True]) -@pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype): - if varBC_groups > 1 and (not is_variable_B or not is_variable_C): - pytest.skip() # This config is not applicable - device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - if has_z: # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - dim = 4 - dstate = 8 - is_complex = wtype == torch.complex64 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - if not is_variable_B: - B_shape = (dim, dstate) - elif varBC_groups == 1: - B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, - requires_grad=True) - if not is_variable_C: - C_shape = (dim, dstate) - elif varBC_groups == 1: - C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) - else: - C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, - requires_grad=True) - if has_D: - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - else: - D = None - if has_z: - z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) - else: - z = None - if has_delta_bias: - delta_bias = (0.5 * torch.rand(dim, device=device, 
dtype=torch.float32)).requires_grad_() - else: - delta_bias = None - u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) - delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() - C_ref = C.detach().clone().requires_grad_() - D_ref = D.detach().clone().requires_grad_() if D is not None else None - z_ref = z.detach().clone().requires_grad_() if z is not None else None - u_ref = u.detach().clone().requires_grad_() - delta_ref = delta.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out, *rest = selective_scan_fn( - u, delta, A, B, C, D, z=z, - delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state - ) - if return_last_state: - state = rest[0] - out_ref, *rest = selective_scan_ref( - u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, - delta_bias=delta_bias_ref, delta_softplus=delta_softplus, - return_last_state=return_last_state - ) - if return_last_state: - state_ref = rest[0] - # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - if return_last_state: - print(f'State max diff: {(state - state_ref).abs().max().item()}') - assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) - - g = torch.randn_like(out) - out_ref.backward(g) - out.backward(g) - - print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') - print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') - print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') - print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') - print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') - if has_D: - print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') - if has_z: - print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') - if has_delta_bias: - print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') - - assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) - assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) - assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) - assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, - atol=atolw if not is_variable_B else atol) - assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, - atol=atolw if not is_variable_C else atol) - if has_D: - assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) - if has_z: - assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) - if has_delta_bias: - assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) - - -@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) -# @pytest.mark.parametrize('wtype', [torch.complex64]) -# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('itype', [torch.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize('seqlen', [128]) 
-@pytest.mark.parametrize("is_variable_C", [False, True]) -# @pytest.mark.parametrize("is_variable_C", [False]) -@pytest.mark.parametrize("is_variable_B", [False, True]) -# @pytest.mark.parametrize("is_variable_B", [True]) -def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): - device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - dim = 768 - dstate = 8 - dt_rank = 48 - is_complex = wtype == torch.complex64 - xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) - conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) - conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate - * (1 if not is_complex else 2), - dim, device=device, dtype=itype, requires_grad=True) - delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) - out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) - out_proj_bias = None - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_B else None) - C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_C else None) - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() - B_proj_bias = None - C_proj_bias = None - xz_ref = xz.detach().clone().requires_grad_() - conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() - conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() - x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() - delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() - out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() - out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() - if out_proj_bias is not None else None) - A_ref = A.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() if B is not None else None - C_ref = C.detach().clone().requires_grad_() if C is not None else None - D_ref = D.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias=delta_bias, delta_softplus=True) - out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, - delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, - A_ref, B_ref, C_ref, D_ref, - delta_bias=delta_bias_ref, delta_softplus=True) - # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - print("mamba_inner_fn") - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - g = torch.randn_like(out) - 
out_ref.backward(g) - out.backward(g) - - print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') - print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') - if not is_variable_B: - print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') - if not is_variable_C: - print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') - print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') - print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') - print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}') - print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') - print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') - print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') - print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') - - # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) - # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) - # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) - # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, - # atol=atolw if not is_variable_B else atol) - # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, - # atol=atolw if not is_variable_C else atol) - # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) - # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) - - -# test_mamba_inner_fn(False, False, 128, torch.float32, torch.float32) - - -@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) -# @pytest.mark.parametrize('wtype', [torch.complex64]) -# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('itype', [torch.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize("is_variable_C", [False, True]) -# @pytest.mark.parametrize("is_variable_C", [False]) -@pytest.mark.parametrize("is_variable_B", [False, True]) -# @pytest.mark.parametrize("is_variable_B", [True]) -def test_bimamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): - device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - torch.random.manual_seed(0) - batch_size = 2 - dim = 768 - dstate = 8 - dt_rank = 48 - is_complex = wtype == torch.complex64 - xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) - conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) - conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate - * (1 if not is_complex else 2), - dim, device=device, dtype=itype, requires_grad=True) - delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) - out_proj_weight 
= torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) - out_proj_bias = None - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - A_b = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_B else None) - C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_C else None) - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() - B_proj_bias = None - C_proj_bias = None - xz_ref = xz.detach().clone().requires_grad_() - conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() - conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() - x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() - delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() - out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() - out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() - if out_proj_bias is not None else None) - A_ref = A.detach().clone().requires_grad_() - A_b_ref = A_b.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() if B is not None else None - C_ref = C.detach().clone().requires_grad_() if C is not None else None - D_ref = D.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out = bimamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, A_b, B, C, D, delta_bias=delta_bias, delta_softplus=True) - out_ref = bimamba_inner_fn(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, - delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, - A_ref, A_b_ref, B_ref, C_ref, D_ref, - delta_bias=delta_bias_ref, delta_softplus=True) - # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - # dt_u = delta * u - print("bimamba_inner_fn") - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - g = torch.randn_like(out) - out_ref.backward(g) - out.backward(g) - - print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') - print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') - print(f'dA_b max diff: {(A_b.grad - A_b_ref.grad).abs().max().item()}') - if not is_variable_B: - print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') - if not is_variable_C: - print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') - print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') - print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') - print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}') - print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') - print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') - print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') - print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') - -@pytest.mark.parametrize('wtype', 
[torch.float32, torch.complex64]) -# @pytest.mark.parametrize('wtype', [torch.complex64]) -# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('itype', [torch.float32]) -# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) -@pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize("is_variable_C", [False, True]) -# @pytest.mark.parametrize("is_variable_C", [False]) -@pytest.mark.parametrize("is_variable_B", [False, True]) -# @pytest.mark.parametrize("is_variable_B", [True]) -def test_bimamba_inner_fn_grad_check(is_variable_B, is_variable_C, seqlen, itype, wtype): - device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) - # set seed - torch.random.manual_seed(0) - batch_size = 2 // 2 - dim = 768 // 8 - dstate = 8 // 8 - dt_rank = 48 // 8 - is_complex = wtype == torch.complex64 - xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) - conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) - conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate - * (1 if not is_complex else 2), - dim, device=device, dtype=itype, requires_grad=True) - delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) - out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) - out_proj_bias = None - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - A_b = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() - B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_B else None) - C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) - if not is_variable_C else None) - D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() - B_proj_bias = None - C_proj_bias = None - xz_ref = xz.detach().clone().requires_grad_() - conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() - conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() - x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() - delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() - out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() - out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() - if out_proj_bias is not None else None) - A_ref = A.detach().clone().requires_grad_() - A_b_ref = A_b.detach().clone().requires_grad_() - B_ref = B.detach().clone().requires_grad_() if B is not None else None - C_ref = C.detach().clone().requires_grad_() if C is not None else None - D_ref = D.detach().clone().requires_grad_() - delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - - # func = bimamba_inner_fn - # func = mamba_inner_fn - func = mamba_inner_ref - - # gradok = gradcheck(func, (xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias, A, A_b, B, C, 
D, delta_bias, None, None, True)) - gradok = gradcheck(func, (xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias, A, B, C, D, delta_bias, None, None, True), eps=1e-6, atol=1e-4, nondet_tol=1.) - print(f'* {gradok} check_gradient_numerical bimamba_inner_fn') - - - -# test_bimamba_inner_fn(True, True, 128, torch.float32, torch.float32) -# test_mamba_inner_fn(True, True, 128, torch.float32, torch.float32) -test_bimamba_inner_fn_grad_check(True, True, 128, torch.float32, torch.float32) - -# input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True)) -# test = gradcheck(torch.nn.functional.linear, input, eps=1e-6, atol=1e-4) -# print(test) \ No newline at end of file diff --git a/mamba/tests/ops/triton/test_selective_state_update.py b/mamba/tests/ops/triton/test_selective_state_update.py deleted file mode 100644 index 70a8d79d9cad3e4d33897478caf178bd96d0ae5a..0000000000000000000000000000000000000000 --- a/mamba/tests/ops/triton/test_selective_state_update.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (C) 2023, Tri Dao. - -import math - -import torch -import torch.nn.functional as F -import pytest - -from einops import rearrange - -from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref - - -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize('itype', [torch.float16]) -@pytest.mark.parametrize("has_z", [False, True]) -# @pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) -# @pytest.mark.parametrize("dstate", [16]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -# @pytest.mark.parametrize("dim", [2048]) -def test_causal_conv1d_update(dim, dstate, has_z, itype): - device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - # set seed - torch.random.manual_seed(0) - batch_size = 2 - state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) - x = torch.randn(batch_size, dim, device=device, dtype=itype) - dt = torch.randn(batch_size, dim, device=device, dtype=itype) - dt_bias = torch.rand(dim, device=device) - 4.0 - A = -torch.rand(dim, dstate, device=device) - 1.0 - B = torch.randn(batch_size, dstate, device=device) - C = torch.randn(batch_size, dstate, device=device) - D = torch.randn(dim, device=device) - if has_z: - z = torch.randn_like(x) - else: - z = None - state_ref = state.detach().clone() - out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl b/mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..644884621c7f4f5c570309baa5e76da306e5afb2 --- /dev/null +++ b/mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96c19c64be06f0bb5607c6a2406fdf7289a0d83037ea005ec56a2e4ae5d40ec7 +size 145927203 diff --git a/requirements.txt b/requirements.txt 
index 4ab553c59fb25f9037863b9ff8728087e2cc7685..8cb776f0a752ee9c9cb1a00d5bab959a8f778bfa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,6 @@ setuptools timm transformers wheel -# torch --index-url https://download.pytorch.org/whl/cu118 -# torchvision --index-url https://download.pytorch.org/whl/cu118 \ No newline at end of file +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.1.1 +torchvision==0.16.1
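
Note on the requirements change above: the hunk pins torch 2.1.1 and torchvision 0.16.1 and resolves them against the CUDA 11.8 extra index. A minimal, illustrative sketch for sanity-checking an environment against those pins (this check is an assumption of how one might verify the install, not part of the repository):

    # Sanity-check the pinned versions from requirements.txt (assumes the packages are installed).
    from importlib.metadata import version

    expected = {"torch": "2.1.1", "torchvision": "0.16.1"}
    for name, want in expected.items():
        got = version(name)
        # cu118 builds may carry a local tag such as "2.1.1+cu118"; compare only the public part.
        assert got.split("+")[0] == want, f"{name}: expected {want}, got {got}"
    print("pinned requirements satisfied")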