Spaces:
Sleeping
Sleeping
Upload 28 files
Browse files- app.py +62 -0
- causal-conv1d/.github/workflows/publish.yaml +209 -0
- causal-conv1d/.gitignore +6 -0
- causal-conv1d/AUTHORS +1 -0
- causal-conv1d/LICENSE +29 -0
- causal-conv1d/README.md +43 -0
- causal-conv1d/build/lib/causal_conv1d/__init__.py +3 -0
- causal-conv1d/build/lib/causal_conv1d/causal_conv1d_interface.py +239 -0
- causal-conv1d/build/lib/causal_conv1d/causal_conv1d_varlen.py +86 -0
- causal-conv1d/causal_conv1d.egg-info/PKG-INFO +62 -0
- causal-conv1d/causal_conv1d.egg-info/SOURCES.txt +12 -0
- causal-conv1d/causal_conv1d.egg-info/dependency_links.txt +1 -0
- causal-conv1d/causal_conv1d.egg-info/requires.txt +3 -0
- causal-conv1d/causal_conv1d.egg-info/top_level.txt +1 -0
- causal-conv1d/causal_conv1d/__init__.py +3 -0
- causal-conv1d/causal_conv1d/causal_conv1d_interface.py +239 -0
- causal-conv1d/causal_conv1d/causal_conv1d_varlen.py +86 -0
- causal-conv1d/csrc/causal_conv1d.cpp +464 -0
- causal-conv1d/csrc/causal_conv1d.h +77 -0
- causal-conv1d/csrc/causal_conv1d_bwd.cu +627 -0
- causal-conv1d/csrc/causal_conv1d_common.h +98 -0
- causal-conv1d/csrc/causal_conv1d_fwd.cu +399 -0
- causal-conv1d/csrc/causal_conv1d_update.cu +130 -0
- causal-conv1d/csrc/static_switch.h +25 -0
- causal-conv1d/dist/causal_conv1d-1.4.0-py3.9.egg +0 -0
- causal-conv1d/rocm_patch/rocm6_0.patch +56 -0
- causal-conv1d/setup.py +296 -0
- causal-conv1d/tests/test_causal_conv1d.py +301 -0
app.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
def generate_prompt(instruction, input=""):
|
| 6 |
+
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
| 7 |
+
input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
|
| 8 |
+
if input:
|
| 9 |
+
return f"""Instruction: {instruction}
|
| 10 |
+
|
| 11 |
+
Input: {input}
|
| 12 |
+
|
| 13 |
+
Response:"""
|
| 14 |
+
else:
|
| 15 |
+
return f"""User: hi
|
| 16 |
+
|
| 17 |
+
Lover: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 18 |
+
|
| 19 |
+
User: {instruction}
|
| 20 |
+
|
| 21 |
+
Lover:"""
|
| 22 |
+
|
| 23 |
+
model_path = "models/rwkv-6-world-1b6/" # Path to your local model directory
|
| 24 |
+
|
| 25 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 26 |
+
model_path,
|
| 27 |
+
trust_remote_code=True,
|
| 28 |
+
use_flash_attention_2=False # Explicitly disable Flash Attention
|
| 29 |
+
).to(torch.float32)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 33 |
+
model_path,
|
| 34 |
+
bos_token="</s>",
|
| 35 |
+
eos_token="</ s>",
|
| 36 |
+
unk_token="<unk>",
|
| 37 |
+
pad_token="<pad>",
|
| 38 |
+
trust_remote_code=True,
|
| 39 |
+
padding_side='left',
|
| 40 |
+
clean_up_tokenization_spaces=False # Or set to True if you prefer
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
print(tokenizer.special_tokens_map)
|
| 44 |
+
|
| 45 |
+
text = "Hi"
|
| 46 |
+
|
| 47 |
+
prompt = generate_prompt(text)
|
| 48 |
+
|
| 49 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
| 50 |
+
|
| 51 |
+
# Generate text word by word with stop sequence
|
| 52 |
+
generated_text = ""
|
| 53 |
+
for i in range(333): # Generate up to 333 tokens
|
| 54 |
+
output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)
|
| 55 |
+
new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)
|
| 56 |
+
|
| 57 |
+
print(new_word, end="", flush=True) # Print word-by-word
|
| 58 |
+
generated_text += new_word
|
| 59 |
+
|
| 60 |
+
input_ids = output # Update input_ids for next iteration
|
| 61 |
+
|
| 62 |
+
print() # Add a newline at the end
|
causal-conv1d/.github/workflows/publish.yaml
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This workflow will:
|
| 2 |
+
# - Create a new Github release
|
| 3 |
+
# - Build wheels for supported architectures
|
| 4 |
+
# - Deploy the wheels to the Github release
|
| 5 |
+
# - Release the static code to PyPi
|
| 6 |
+
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
| 7 |
+
|
| 8 |
+
name: Build wheels and deploy
|
| 9 |
+
|
| 10 |
+
on:
|
| 11 |
+
create:
|
| 12 |
+
tags:
|
| 13 |
+
- v*
|
| 14 |
+
|
| 15 |
+
jobs:
|
| 16 |
+
|
| 17 |
+
setup_release:
|
| 18 |
+
name: Create Release
|
| 19 |
+
runs-on: ubuntu-latest
|
| 20 |
+
steps:
|
| 21 |
+
- name: Get the tag version
|
| 22 |
+
id: extract_branch
|
| 23 |
+
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
| 24 |
+
shell: bash
|
| 25 |
+
|
| 26 |
+
- name: Create Release
|
| 27 |
+
id: create_release
|
| 28 |
+
uses: actions/create-release@v1
|
| 29 |
+
env:
|
| 30 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 31 |
+
with:
|
| 32 |
+
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
| 33 |
+
release_name: ${{ steps.extract_branch.outputs.branch }}
|
| 34 |
+
|
| 35 |
+
build_wheels:
|
| 36 |
+
name: Build Wheel
|
| 37 |
+
needs: setup_release
|
| 38 |
+
runs-on: ${{ matrix.os }}
|
| 39 |
+
|
| 40 |
+
strategy:
|
| 41 |
+
fail-fast: false
|
| 42 |
+
matrix:
|
| 43 |
+
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
|
| 44 |
+
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
|
| 45 |
+
os: [ubuntu-20.04]
|
| 46 |
+
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
| 47 |
+
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240505']
|
| 48 |
+
cuda-version: ['11.8.0', '12.2.2']
|
| 49 |
+
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
|
| 50 |
+
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
|
| 51 |
+
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
|
| 52 |
+
# when building without C++11 ABI and using it on nvcr images.
|
| 53 |
+
cxx11_abi: ['FALSE', 'TRUE']
|
| 54 |
+
exclude:
|
| 55 |
+
# Pytorch < 2.2 does not support Python 3.12
|
| 56 |
+
- torch-version: '2.0.1'
|
| 57 |
+
python-version: '3.12'
|
| 58 |
+
- torch-version: '2.1.2'
|
| 59 |
+
python-version: '3.12'
|
| 60 |
+
# Pytorch <= 2.0 only supports CUDA <= 11.8
|
| 61 |
+
- torch-version: '2.0.1'
|
| 62 |
+
cuda-version: '12.2.2'
|
| 63 |
+
|
| 64 |
+
steps:
|
| 65 |
+
- name: Checkout
|
| 66 |
+
uses: actions/checkout@v3
|
| 67 |
+
|
| 68 |
+
- name: Set up Python
|
| 69 |
+
uses: actions/setup-python@v4
|
| 70 |
+
with:
|
| 71 |
+
python-version: ${{ matrix.python-version }}
|
| 72 |
+
|
| 73 |
+
- name: Set CUDA and PyTorch versions
|
| 74 |
+
run: |
|
| 75 |
+
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
|
| 76 |
+
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
|
| 77 |
+
|
| 78 |
+
- name: Free up disk space
|
| 79 |
+
if: ${{ runner.os == 'Linux' }}
|
| 80 |
+
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
|
| 81 |
+
# https://github.com/easimon/maximize-build-space/tree/test-report
|
| 82 |
+
run: |
|
| 83 |
+
sudo rm -rf /usr/share/dotnet
|
| 84 |
+
sudo rm -rf /opt/ghc
|
| 85 |
+
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
| 86 |
+
|
| 87 |
+
- name: Set up swap space
|
| 88 |
+
if: runner.os == 'Linux'
|
| 89 |
+
uses: pierotofy/[email protected]
|
| 90 |
+
with:
|
| 91 |
+
swap-size-gb: 10
|
| 92 |
+
|
| 93 |
+
- name: Install CUDA ${{ matrix.cuda-version }}
|
| 94 |
+
if: ${{ matrix.cuda-version != 'cpu' }}
|
| 95 |
+
uses: Jimver/[email protected]
|
| 96 |
+
id: cuda-toolkit
|
| 97 |
+
with:
|
| 98 |
+
cuda: ${{ matrix.cuda-version }}
|
| 99 |
+
linux-local-args: '["--toolkit"]'
|
| 100 |
+
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
|
| 101 |
+
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
|
| 102 |
+
method: 'network'
|
| 103 |
+
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
|
| 104 |
+
# not just nvcc
|
| 105 |
+
# sub-packages: '["nvcc"]'
|
| 106 |
+
|
| 107 |
+
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
|
| 108 |
+
run: |
|
| 109 |
+
pip install --upgrade pip
|
| 110 |
+
# If we don't install before installing Pytorch, we get error for torch 2.0.1
|
| 111 |
+
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
|
| 112 |
+
pip install lit
|
| 113 |
+
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
|
| 114 |
+
pip install setuptools==68.0.0
|
| 115 |
+
# We want to figure out the CUDA version to download pytorch
|
| 116 |
+
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
|
| 117 |
+
# This code is ugly, maybe there's a better way to do this.
|
| 118 |
+
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
|
| 119 |
+
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
|
| 120 |
+
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
|
| 121 |
+
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
|
| 122 |
+
)
|
| 123 |
+
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
|
| 124 |
+
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
|
| 125 |
+
else
|
| 126 |
+
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
|
| 127 |
+
fi
|
| 128 |
+
nvcc --version
|
| 129 |
+
python --version
|
| 130 |
+
python -c "import torch; print('PyTorch:', torch.__version__)"
|
| 131 |
+
python -c "import torch; print('CUDA:', torch.version.cuda)"
|
| 132 |
+
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
| 133 |
+
shell:
|
| 134 |
+
bash
|
| 135 |
+
|
| 136 |
+
- name: Build wheel
|
| 137 |
+
run: |
|
| 138 |
+
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
|
| 139 |
+
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
|
| 140 |
+
# However this still fails so I'm using a newer version of setuptools
|
| 141 |
+
pip install setuptools==68.0.0
|
| 142 |
+
pip install ninja packaging wheel
|
| 143 |
+
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
|
| 144 |
+
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
| 145 |
+
# Limit MAX_JOBS otherwise the github runner goes OOM
|
| 146 |
+
MAX_JOBS=2 CAUSAL_CONV1D_FORCE_BUILD="TRUE" CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
|
| 147 |
+
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
|
| 148 |
+
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
|
| 149 |
+
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
|
| 150 |
+
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
| 151 |
+
|
| 152 |
+
- name: Log Built Wheels
|
| 153 |
+
run: |
|
| 154 |
+
ls dist
|
| 155 |
+
|
| 156 |
+
- name: Get the tag version
|
| 157 |
+
id: extract_branch
|
| 158 |
+
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
| 159 |
+
|
| 160 |
+
- name: Get Release with tag
|
| 161 |
+
id: get_current_release
|
| 162 |
+
uses: joutvhu/get-release@v1
|
| 163 |
+
with:
|
| 164 |
+
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
| 165 |
+
env:
|
| 166 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 167 |
+
|
| 168 |
+
- name: Upload Release Asset
|
| 169 |
+
id: upload_release_asset
|
| 170 |
+
uses: actions/upload-release-asset@v1
|
| 171 |
+
env:
|
| 172 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 173 |
+
with:
|
| 174 |
+
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
| 175 |
+
asset_path: ./dist/${{env.wheel_name}}
|
| 176 |
+
asset_name: ${{env.wheel_name}}
|
| 177 |
+
asset_content_type: application/*
|
| 178 |
+
|
| 179 |
+
publish_package:
|
| 180 |
+
name: Publish package
|
| 181 |
+
needs: [build_wheels]
|
| 182 |
+
|
| 183 |
+
runs-on: ubuntu-latest
|
| 184 |
+
|
| 185 |
+
steps:
|
| 186 |
+
- uses: actions/checkout@v3
|
| 187 |
+
|
| 188 |
+
- uses: actions/setup-python@v4
|
| 189 |
+
with:
|
| 190 |
+
python-version: '3.10'
|
| 191 |
+
|
| 192 |
+
- name: Install dependencies
|
| 193 |
+
run: |
|
| 194 |
+
pip install ninja packaging setuptools wheel twine
|
| 195 |
+
# We don't want to download anything CUDA-related here
|
| 196 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 197 |
+
|
| 198 |
+
- name: Build core package
|
| 199 |
+
env:
|
| 200 |
+
CAUSAL_CONV1D_SKIP_CUDA_BUILD: "TRUE"
|
| 201 |
+
run: |
|
| 202 |
+
python setup.py sdist --dist-dir=dist
|
| 203 |
+
|
| 204 |
+
- name: Deploy
|
| 205 |
+
env:
|
| 206 |
+
TWINE_USERNAME: "__token__"
|
| 207 |
+
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
| 208 |
+
run: |
|
| 209 |
+
python -m twine upload dist/*
|
causal-conv1d/.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*__pycache__/
|
| 2 |
+
*.egg-info/
|
| 3 |
+
build/
|
| 4 |
+
**.so
|
| 5 |
+
*.hip
|
| 6 |
+
*_hip.*
|
causal-conv1d/AUTHORS
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Tri Dao, [email protected]
|
causal-conv1d/LICENSE
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 3-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
| 4 |
+
All rights reserved.
|
| 5 |
+
|
| 6 |
+
Redistribution and use in source and binary forms, with or without
|
| 7 |
+
modification, are permitted provided that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
* Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
and/or other materials provided with the distribution.
|
| 15 |
+
|
| 16 |
+
* Neither the name of the copyright holder nor the names of its
|
| 17 |
+
contributors may be used to endorse or promote products derived from
|
| 18 |
+
this software without specific prior written permission.
|
| 19 |
+
|
| 20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
causal-conv1d/README.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Causal depthwise conv1d in CUDA with a PyTorch interface
|
| 2 |
+
|
| 3 |
+
Features:
|
| 4 |
+
- Support fp32, fp16, bf16.
|
| 5 |
+
- Kernel size 2, 3, 4.
|
| 6 |
+
|
| 7 |
+
## How to use
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
from causal_conv1d import causal_conv1d_fn
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
| 15 |
+
"""
|
| 16 |
+
x: (batch, dim, seqlen)
|
| 17 |
+
weight: (dim, width)
|
| 18 |
+
bias: (dim,)
|
| 19 |
+
activation: either None or "silu" or "swish"
|
| 20 |
+
|
| 21 |
+
out: (batch, dim, seqlen)
|
| 22 |
+
"""
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Equivalent to:
|
| 26 |
+
```
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
## Additional Prerequisites for AMD cards
|
| 33 |
+
|
| 34 |
+
### Patching ROCm
|
| 35 |
+
|
| 36 |
+
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
|
| 37 |
+
|
| 38 |
+
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
|
| 39 |
+
|
| 40 |
+
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
|
| 41 |
+
```bash
|
| 42 |
+
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
|
| 43 |
+
```
|
causal-conv1d/build/lib/causal_conv1d/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.4.0"
|
| 2 |
+
|
| 3 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
causal-conv1d/build/lib/causal_conv1d/causal_conv1d_interface.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import causal_conv1d_cuda
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CausalConv1dFn(torch.autograd.Function):
|
| 11 |
+
@staticmethod
|
| 12 |
+
def forward(
|
| 13 |
+
ctx,
|
| 14 |
+
x,
|
| 15 |
+
weight,
|
| 16 |
+
bias=None,
|
| 17 |
+
seq_idx=None,
|
| 18 |
+
initial_states=None,
|
| 19 |
+
return_final_states=False,
|
| 20 |
+
final_states_out=None,
|
| 21 |
+
activation=None,
|
| 22 |
+
):
|
| 23 |
+
if activation not in [None, "silu", "swish"]:
|
| 24 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 25 |
+
if x.stride(2) != 1 and x.stride(1) != 1:
|
| 26 |
+
x = x.contiguous()
|
| 27 |
+
bias = bias.contiguous() if bias is not None else None
|
| 28 |
+
if seq_idx is not None:
|
| 29 |
+
assert (
|
| 30 |
+
initial_states is None
|
| 31 |
+
), "initial_states must be None if seq_idx is not None"
|
| 32 |
+
assert (
|
| 33 |
+
not return_final_states
|
| 34 |
+
), "If seq_idx is not None, we don't return final_states_out"
|
| 35 |
+
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
| 36 |
+
if initial_states is not None and (
|
| 37 |
+
initial_states.stride(2) != 1 and initial_states.stride(1) != 1
|
| 38 |
+
):
|
| 39 |
+
initial_states = initial_states.contiguous()
|
| 40 |
+
if return_final_states:
|
| 41 |
+
assert (
|
| 42 |
+
x.stride(1) == 1
|
| 43 |
+
), "Only channel-last layout support returning final_states_out"
|
| 44 |
+
if final_states_out is not None:
|
| 45 |
+
assert (
|
| 46 |
+
final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
|
| 47 |
+
)
|
| 48 |
+
else:
|
| 49 |
+
batch, dim, seqlen = x.shape
|
| 50 |
+
width = weight.shape[1]
|
| 51 |
+
final_states_out = torch.empty(
|
| 52 |
+
batch, width - 1, dim, device=x.device, dtype=x.dtype
|
| 53 |
+
).transpose(1, 2)
|
| 54 |
+
else:
|
| 55 |
+
final_states_out = None
|
| 56 |
+
ctx.activation = activation in ["silu", "swish"]
|
| 57 |
+
out = causal_conv1d_cuda.causal_conv1d_fwd(
|
| 58 |
+
x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
|
| 59 |
+
)
|
| 60 |
+
ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
|
| 61 |
+
ctx.return_final_states = return_final_states
|
| 62 |
+
ctx.return_dinitial_states = (
|
| 63 |
+
initial_states is not None and initial_states.requires_grad
|
| 64 |
+
)
|
| 65 |
+
return out if not return_final_states else (out, final_states_out)
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def backward(ctx, dout, *args):
|
| 69 |
+
x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
|
| 70 |
+
dfinal_states = args[0] if ctx.return_final_states else None
|
| 71 |
+
if dout.stride(2) != 1 and dout.stride(1) != 1:
|
| 72 |
+
dout = dout.contiguous()
|
| 73 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
| 74 |
+
# backward of conv1d with the backward of chunk).
|
| 75 |
+
# Here we just pass in None and dx will be allocated in the C++ code.
|
| 76 |
+
dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
|
| 77 |
+
x,
|
| 78 |
+
weight,
|
| 79 |
+
bias,
|
| 80 |
+
dout,
|
| 81 |
+
seq_idx,
|
| 82 |
+
initial_states,
|
| 83 |
+
dfinal_states,
|
| 84 |
+
None,
|
| 85 |
+
ctx.return_dinitial_states,
|
| 86 |
+
ctx.activation,
|
| 87 |
+
)
|
| 88 |
+
return (
|
| 89 |
+
dx,
|
| 90 |
+
dweight,
|
| 91 |
+
dbias if bias is not None else None,
|
| 92 |
+
None,
|
| 93 |
+
dinitial_states if initial_states is not None else None,
|
| 94 |
+
None,
|
| 95 |
+
None,
|
| 96 |
+
None,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def causal_conv1d_fn(
|
| 101 |
+
x,
|
| 102 |
+
weight,
|
| 103 |
+
bias=None,
|
| 104 |
+
seq_idx=None,
|
| 105 |
+
initial_states=None,
|
| 106 |
+
return_final_states=False,
|
| 107 |
+
final_states_out=None,
|
| 108 |
+
activation=None,
|
| 109 |
+
):
|
| 110 |
+
"""
|
| 111 |
+
x: (batch, dim, seqlen)
|
| 112 |
+
weight: (dim, width)
|
| 113 |
+
bias: (dim,)
|
| 114 |
+
seq_idx: (batch, seqlen)
|
| 115 |
+
initial_states: (batch, dim, width - 1)
|
| 116 |
+
final_states_out: (batch, dim, width - 1), to be written to
|
| 117 |
+
activation: either None or "silu" or "swish"
|
| 118 |
+
|
| 119 |
+
out: (batch, dim, seqlen)
|
| 120 |
+
"""
|
| 121 |
+
return CausalConv1dFn.apply(
|
| 122 |
+
x,
|
| 123 |
+
weight,
|
| 124 |
+
bias,
|
| 125 |
+
seq_idx,
|
| 126 |
+
initial_states,
|
| 127 |
+
return_final_states,
|
| 128 |
+
final_states_out,
|
| 129 |
+
activation,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def causal_conv1d_ref(
|
| 134 |
+
x,
|
| 135 |
+
weight,
|
| 136 |
+
bias=None,
|
| 137 |
+
initial_states=None,
|
| 138 |
+
return_final_states=False,
|
| 139 |
+
final_states_out=None,
|
| 140 |
+
activation=None,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
x: (batch, dim, seqlen)
|
| 144 |
+
weight: (dim, width)
|
| 145 |
+
bias: (dim,)
|
| 146 |
+
initial_states: (batch, dim, width - 1)
|
| 147 |
+
final_states_out: (batch, dim, width - 1)
|
| 148 |
+
|
| 149 |
+
out: (batch, dim, seqlen)
|
| 150 |
+
"""
|
| 151 |
+
if activation not in [None, "silu", "swish"]:
|
| 152 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 153 |
+
dtype_in = x.dtype
|
| 154 |
+
x = x.to(weight.dtype)
|
| 155 |
+
seqlen = x.shape[-1]
|
| 156 |
+
dim, width = weight.shape
|
| 157 |
+
if initial_states is None:
|
| 158 |
+
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
| 159 |
+
else:
|
| 160 |
+
x = torch.cat([initial_states, x], dim=-1)
|
| 161 |
+
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
| 162 |
+
out = out[..., :seqlen]
|
| 163 |
+
if return_final_states:
|
| 164 |
+
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
| 165 |
+
dtype_in
|
| 166 |
+
) # (batch, dim, width - 1)
|
| 167 |
+
if final_states_out is not None:
|
| 168 |
+
final_states_out.copy_(final_states)
|
| 169 |
+
else:
|
| 170 |
+
final_states_out = final_states
|
| 171 |
+
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
| 172 |
+
return out if not return_final_states else (out, final_states_out)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
| 176 |
+
"""
|
| 177 |
+
x: (batch, dim) or (batch, dim, seqlen)
|
| 178 |
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 179 |
+
weight: (dim, width)
|
| 180 |
+
bias: (dim,)
|
| 181 |
+
cache_seqlens: (batch,), dtype int32.
|
| 182 |
+
If not None, the conv_state is treated as a circular buffer.
|
| 183 |
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 184 |
+
@cache_seqlens % state_len.
|
| 185 |
+
|
| 186 |
+
out: (batch, dim) or (batch, dim, seqlen)
|
| 187 |
+
"""
|
| 188 |
+
if activation not in [None, "silu", "swish"]:
|
| 189 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 190 |
+
activation = activation in ["silu", "swish"]
|
| 191 |
+
unsqueeze = x.dim() == 2
|
| 192 |
+
if unsqueeze:
|
| 193 |
+
x = x.unsqueeze(-1)
|
| 194 |
+
out = causal_conv1d_cuda.causal_conv1d_update(
|
| 195 |
+
x, conv_state, weight, bias, activation, cache_seqlens
|
| 196 |
+
)
|
| 197 |
+
if unsqueeze:
|
| 198 |
+
out = out.squeeze(-1)
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
| 203 |
+
"""
|
| 204 |
+
x: (batch, dim) or (batch, dim, seqlen)
|
| 205 |
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 206 |
+
weight: (dim, width)
|
| 207 |
+
bias: (dim,)
|
| 208 |
+
cache_seqlens: (batch,), dtype int32.
|
| 209 |
+
If not None, the conv_state is treated as a circular buffer.
|
| 210 |
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 211 |
+
@cache_seqlens % state_len before performing the convolution.
|
| 212 |
+
|
| 213 |
+
out: (batch, dim) or (batch, dim, seqlen)
|
| 214 |
+
"""
|
| 215 |
+
if activation not in [None, "silu", "swish"]:
|
| 216 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 217 |
+
dtype_in = x.dtype
|
| 218 |
+
unsqueeze = x.dim() == 2
|
| 219 |
+
if unsqueeze:
|
| 220 |
+
x = x.unsqueeze(-1)
|
| 221 |
+
batch, dim, seqlen = x.shape
|
| 222 |
+
width = weight.shape[1]
|
| 223 |
+
state_len = conv_state.shape[-1]
|
| 224 |
+
assert conv_state.shape == (batch, dim, state_len)
|
| 225 |
+
assert weight.shape == (dim, width)
|
| 226 |
+
if cache_seqlens is None:
|
| 227 |
+
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
|
| 228 |
+
conv_state.copy_(x_new[:, :, -state_len:])
|
| 229 |
+
else:
|
| 230 |
+
width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 231 |
+
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 232 |
+
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
| 233 |
+
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 234 |
+
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 235 |
+
conv_state.scatter_(2, copy_idx, x)
|
| 236 |
+
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
|
| 237 |
+
if unsqueeze:
|
| 238 |
+
out = out.squeeze(-1)
|
| 239 |
+
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
causal-conv1d/build/lib/causal_conv1d/causal_conv1d_varlen.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@triton.jit
|
| 9 |
+
def _causal_conv1d_varlen_states(
|
| 10 |
+
X,
|
| 11 |
+
CU_SEQLENS,
|
| 12 |
+
STATES,
|
| 13 |
+
state_len,
|
| 14 |
+
dim,
|
| 15 |
+
stride_x_seqlen, stride_x_dim,
|
| 16 |
+
stride_states_batch, stride_states_seqlen, stride_states_dim,
|
| 17 |
+
BLOCK_M: tl.constexpr,
|
| 18 |
+
BLOCK_N: tl.constexpr
|
| 19 |
+
):
|
| 20 |
+
batch_idx = tl.program_id(2)
|
| 21 |
+
STATES += batch_idx * stride_states_batch
|
| 22 |
+
end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
|
| 23 |
+
start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
|
| 24 |
+
rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 25 |
+
cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 26 |
+
x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
|
| 27 |
+
mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
|
| 28 |
+
other=0)
|
| 29 |
+
rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 30 |
+
tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
|
| 31 |
+
x,
|
| 32 |
+
mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 36 |
+
"""
|
| 37 |
+
Forward pass only, does not support backward pass.
|
| 38 |
+
Parameters:
|
| 39 |
+
x: (total_tokens, dim)
|
| 40 |
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 41 |
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 42 |
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 43 |
+
Return:
|
| 44 |
+
states: (batch, dim, state_len)
|
| 45 |
+
"""
|
| 46 |
+
_, dim = x.shape
|
| 47 |
+
batch = cu_seqlens.shape[0] - 1
|
| 48 |
+
cu_seqlens = cu_seqlens.contiguous()
|
| 49 |
+
states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 50 |
+
BLOCK_M = min(triton.next_power_of_2(state_len), 16)
|
| 51 |
+
BLOCK_N = min(triton.next_power_of_2(dim), 256)
|
| 52 |
+
grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
|
| 53 |
+
with torch.cuda.device(x.device.index):
|
| 54 |
+
_causal_conv1d_varlen_states[grid](
|
| 55 |
+
x,
|
| 56 |
+
cu_seqlens,
|
| 57 |
+
states,
|
| 58 |
+
state_len,
|
| 59 |
+
dim,
|
| 60 |
+
x.stride(0), x.stride(1),
|
| 61 |
+
states.stride(0), states.stride(2), states.stride(1),
|
| 62 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
| 63 |
+
)
|
| 64 |
+
return states
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 68 |
+
"""
|
| 69 |
+
Forward pass only, does not support backward pass.
|
| 70 |
+
Parameters:
|
| 71 |
+
x: (total_tokens, dim)
|
| 72 |
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 73 |
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 74 |
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 75 |
+
Return:
|
| 76 |
+
states: (batch, dim, state_len)
|
| 77 |
+
"""
|
| 78 |
+
_, dim = x.shape
|
| 79 |
+
batch = cu_seqlens.shape[0] - 1
|
| 80 |
+
cu_seqlens = cu_seqlens.contiguous()
|
| 81 |
+
states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 82 |
+
for i in range(batch):
|
| 83 |
+
end_idx = cu_seqlens[i + 1]
|
| 84 |
+
start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
|
| 85 |
+
states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
|
| 86 |
+
return states
|
causal-conv1d/causal_conv1d.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: causal-conv1d
|
| 3 |
+
Version: 1.4.0
|
| 4 |
+
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
|
| 5 |
+
Home-page: https://github.com/Dao-AILab/causal-conv1d
|
| 6 |
+
Author: Tri Dao
|
| 7 |
+
Author-email: [email protected]
|
| 8 |
+
License: UNKNOWN
|
| 9 |
+
Platform: UNKNOWN
|
| 10 |
+
Classifier: Programming Language :: Python :: 3
|
| 11 |
+
Classifier: License :: OSI Approved :: BSD License
|
| 12 |
+
Classifier: Operating System :: Unix
|
| 13 |
+
Requires-Python: >=3.8
|
| 14 |
+
Description-Content-Type: text/markdown
|
| 15 |
+
License-File: LICENSE
|
| 16 |
+
License-File: AUTHORS
|
| 17 |
+
|
| 18 |
+
# Causal depthwise conv1d in CUDA with a PyTorch interface
|
| 19 |
+
|
| 20 |
+
Features:
|
| 21 |
+
- Support fp32, fp16, bf16.
|
| 22 |
+
- Kernel size 2, 3, 4.
|
| 23 |
+
|
| 24 |
+
## How to use
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
from causal_conv1d import causal_conv1d_fn
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
```
|
| 31 |
+
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
| 32 |
+
"""
|
| 33 |
+
x: (batch, dim, seqlen)
|
| 34 |
+
weight: (dim, width)
|
| 35 |
+
bias: (dim,)
|
| 36 |
+
activation: either None or "silu" or "swish"
|
| 37 |
+
|
| 38 |
+
out: (batch, dim, seqlen)
|
| 39 |
+
"""
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Equivalent to:
|
| 43 |
+
```
|
| 44 |
+
import torch.nn.functional as F
|
| 45 |
+
|
| 46 |
+
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Additional Prerequisites for AMD cards
|
| 50 |
+
|
| 51 |
+
### Patching ROCm
|
| 52 |
+
|
| 53 |
+
If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
|
| 54 |
+
|
| 55 |
+
1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
|
| 56 |
+
|
| 57 |
+
2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
|
| 58 |
+
```bash
|
| 59 |
+
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
|
causal-conv1d/causal_conv1d.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AUTHORS
|
| 2 |
+
LICENSE
|
| 3 |
+
README.md
|
| 4 |
+
setup.py
|
| 5 |
+
causal_conv1d/__init__.py
|
| 6 |
+
causal_conv1d/causal_conv1d_interface.py
|
| 7 |
+
causal_conv1d/causal_conv1d_varlen.py
|
| 8 |
+
causal_conv1d.egg-info/PKG-INFO
|
| 9 |
+
causal_conv1d.egg-info/SOURCES.txt
|
| 10 |
+
causal_conv1d.egg-info/dependency_links.txt
|
| 11 |
+
causal_conv1d.egg-info/requires.txt
|
| 12 |
+
causal_conv1d.egg-info/top_level.txt
|
causal-conv1d/causal_conv1d.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
causal-conv1d/causal_conv1d.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
packaging
|
| 3 |
+
ninja
|
causal-conv1d/causal_conv1d.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
causal_conv1d
|
causal-conv1d/causal_conv1d/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.4.0"
|
| 2 |
+
|
| 3 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
causal-conv1d/causal_conv1d/causal_conv1d_interface.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import causal_conv1d_cuda
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CausalConv1dFn(torch.autograd.Function):
|
| 11 |
+
@staticmethod
|
| 12 |
+
def forward(
|
| 13 |
+
ctx,
|
| 14 |
+
x,
|
| 15 |
+
weight,
|
| 16 |
+
bias=None,
|
| 17 |
+
seq_idx=None,
|
| 18 |
+
initial_states=None,
|
| 19 |
+
return_final_states=False,
|
| 20 |
+
final_states_out=None,
|
| 21 |
+
activation=None,
|
| 22 |
+
):
|
| 23 |
+
if activation not in [None, "silu", "swish"]:
|
| 24 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 25 |
+
if x.stride(2) != 1 and x.stride(1) != 1:
|
| 26 |
+
x = x.contiguous()
|
| 27 |
+
bias = bias.contiguous() if bias is not None else None
|
| 28 |
+
if seq_idx is not None:
|
| 29 |
+
assert (
|
| 30 |
+
initial_states is None
|
| 31 |
+
), "initial_states must be None if seq_idx is not None"
|
| 32 |
+
assert (
|
| 33 |
+
not return_final_states
|
| 34 |
+
), "If seq_idx is not None, we don't return final_states_out"
|
| 35 |
+
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
| 36 |
+
if initial_states is not None and (
|
| 37 |
+
initial_states.stride(2) != 1 and initial_states.stride(1) != 1
|
| 38 |
+
):
|
| 39 |
+
initial_states = initial_states.contiguous()
|
| 40 |
+
if return_final_states:
|
| 41 |
+
assert (
|
| 42 |
+
x.stride(1) == 1
|
| 43 |
+
), "Only channel-last layout support returning final_states_out"
|
| 44 |
+
if final_states_out is not None:
|
| 45 |
+
assert (
|
| 46 |
+
final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
|
| 47 |
+
)
|
| 48 |
+
else:
|
| 49 |
+
batch, dim, seqlen = x.shape
|
| 50 |
+
width = weight.shape[1]
|
| 51 |
+
final_states_out = torch.empty(
|
| 52 |
+
batch, width - 1, dim, device=x.device, dtype=x.dtype
|
| 53 |
+
).transpose(1, 2)
|
| 54 |
+
else:
|
| 55 |
+
final_states_out = None
|
| 56 |
+
ctx.activation = activation in ["silu", "swish"]
|
| 57 |
+
out = causal_conv1d_cuda.causal_conv1d_fwd(
|
| 58 |
+
x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
|
| 59 |
+
)
|
| 60 |
+
ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
|
| 61 |
+
ctx.return_final_states = return_final_states
|
| 62 |
+
ctx.return_dinitial_states = (
|
| 63 |
+
initial_states is not None and initial_states.requires_grad
|
| 64 |
+
)
|
| 65 |
+
return out if not return_final_states else (out, final_states_out)
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def backward(ctx, dout, *args):
|
| 69 |
+
x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
|
| 70 |
+
dfinal_states = args[0] if ctx.return_final_states else None
|
| 71 |
+
if dout.stride(2) != 1 and dout.stride(1) != 1:
|
| 72 |
+
dout = dout.contiguous()
|
| 73 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
| 74 |
+
# backward of conv1d with the backward of chunk).
|
| 75 |
+
# Here we just pass in None and dx will be allocated in the C++ code.
|
| 76 |
+
dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
|
| 77 |
+
x,
|
| 78 |
+
weight,
|
| 79 |
+
bias,
|
| 80 |
+
dout,
|
| 81 |
+
seq_idx,
|
| 82 |
+
initial_states,
|
| 83 |
+
dfinal_states,
|
| 84 |
+
None,
|
| 85 |
+
ctx.return_dinitial_states,
|
| 86 |
+
ctx.activation,
|
| 87 |
+
)
|
| 88 |
+
return (
|
| 89 |
+
dx,
|
| 90 |
+
dweight,
|
| 91 |
+
dbias if bias is not None else None,
|
| 92 |
+
None,
|
| 93 |
+
dinitial_states if initial_states is not None else None,
|
| 94 |
+
None,
|
| 95 |
+
None,
|
| 96 |
+
None,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def causal_conv1d_fn(
|
| 101 |
+
x,
|
| 102 |
+
weight,
|
| 103 |
+
bias=None,
|
| 104 |
+
seq_idx=None,
|
| 105 |
+
initial_states=None,
|
| 106 |
+
return_final_states=False,
|
| 107 |
+
final_states_out=None,
|
| 108 |
+
activation=None,
|
| 109 |
+
):
|
| 110 |
+
"""
|
| 111 |
+
x: (batch, dim, seqlen)
|
| 112 |
+
weight: (dim, width)
|
| 113 |
+
bias: (dim,)
|
| 114 |
+
seq_idx: (batch, seqlen)
|
| 115 |
+
initial_states: (batch, dim, width - 1)
|
| 116 |
+
final_states_out: (batch, dim, width - 1), to be written to
|
| 117 |
+
activation: either None or "silu" or "swish"
|
| 118 |
+
|
| 119 |
+
out: (batch, dim, seqlen)
|
| 120 |
+
"""
|
| 121 |
+
return CausalConv1dFn.apply(
|
| 122 |
+
x,
|
| 123 |
+
weight,
|
| 124 |
+
bias,
|
| 125 |
+
seq_idx,
|
| 126 |
+
initial_states,
|
| 127 |
+
return_final_states,
|
| 128 |
+
final_states_out,
|
| 129 |
+
activation,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def causal_conv1d_ref(
|
| 134 |
+
x,
|
| 135 |
+
weight,
|
| 136 |
+
bias=None,
|
| 137 |
+
initial_states=None,
|
| 138 |
+
return_final_states=False,
|
| 139 |
+
final_states_out=None,
|
| 140 |
+
activation=None,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
x: (batch, dim, seqlen)
|
| 144 |
+
weight: (dim, width)
|
| 145 |
+
bias: (dim,)
|
| 146 |
+
initial_states: (batch, dim, width - 1)
|
| 147 |
+
final_states_out: (batch, dim, width - 1)
|
| 148 |
+
|
| 149 |
+
out: (batch, dim, seqlen)
|
| 150 |
+
"""
|
| 151 |
+
if activation not in [None, "silu", "swish"]:
|
| 152 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 153 |
+
dtype_in = x.dtype
|
| 154 |
+
x = x.to(weight.dtype)
|
| 155 |
+
seqlen = x.shape[-1]
|
| 156 |
+
dim, width = weight.shape
|
| 157 |
+
if initial_states is None:
|
| 158 |
+
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
| 159 |
+
else:
|
| 160 |
+
x = torch.cat([initial_states, x], dim=-1)
|
| 161 |
+
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
| 162 |
+
out = out[..., :seqlen]
|
| 163 |
+
if return_final_states:
|
| 164 |
+
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
| 165 |
+
dtype_in
|
| 166 |
+
) # (batch, dim, width - 1)
|
| 167 |
+
if final_states_out is not None:
|
| 168 |
+
final_states_out.copy_(final_states)
|
| 169 |
+
else:
|
| 170 |
+
final_states_out = final_states
|
| 171 |
+
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
| 172 |
+
return out if not return_final_states else (out, final_states_out)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
| 176 |
+
"""
|
| 177 |
+
x: (batch, dim) or (batch, dim, seqlen)
|
| 178 |
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 179 |
+
weight: (dim, width)
|
| 180 |
+
bias: (dim,)
|
| 181 |
+
cache_seqlens: (batch,), dtype int32.
|
| 182 |
+
If not None, the conv_state is treated as a circular buffer.
|
| 183 |
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 184 |
+
@cache_seqlens % state_len.
|
| 185 |
+
|
| 186 |
+
out: (batch, dim) or (batch, dim, seqlen)
|
| 187 |
+
"""
|
| 188 |
+
if activation not in [None, "silu", "swish"]:
|
| 189 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 190 |
+
activation = activation in ["silu", "swish"]
|
| 191 |
+
unsqueeze = x.dim() == 2
|
| 192 |
+
if unsqueeze:
|
| 193 |
+
x = x.unsqueeze(-1)
|
| 194 |
+
out = causal_conv1d_cuda.causal_conv1d_update(
|
| 195 |
+
x, conv_state, weight, bias, activation, cache_seqlens
|
| 196 |
+
)
|
| 197 |
+
if unsqueeze:
|
| 198 |
+
out = out.squeeze(-1)
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
| 203 |
+
"""
|
| 204 |
+
x: (batch, dim) or (batch, dim, seqlen)
|
| 205 |
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 206 |
+
weight: (dim, width)
|
| 207 |
+
bias: (dim,)
|
| 208 |
+
cache_seqlens: (batch,), dtype int32.
|
| 209 |
+
If not None, the conv_state is treated as a circular buffer.
|
| 210 |
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 211 |
+
@cache_seqlens % state_len before performing the convolution.
|
| 212 |
+
|
| 213 |
+
out: (batch, dim) or (batch, dim, seqlen)
|
| 214 |
+
"""
|
| 215 |
+
if activation not in [None, "silu", "swish"]:
|
| 216 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 217 |
+
dtype_in = x.dtype
|
| 218 |
+
unsqueeze = x.dim() == 2
|
| 219 |
+
if unsqueeze:
|
| 220 |
+
x = x.unsqueeze(-1)
|
| 221 |
+
batch, dim, seqlen = x.shape
|
| 222 |
+
width = weight.shape[1]
|
| 223 |
+
state_len = conv_state.shape[-1]
|
| 224 |
+
assert conv_state.shape == (batch, dim, state_len)
|
| 225 |
+
assert weight.shape == (dim, width)
|
| 226 |
+
if cache_seqlens is None:
|
| 227 |
+
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
|
| 228 |
+
conv_state.copy_(x_new[:, :, -state_len:])
|
| 229 |
+
else:
|
| 230 |
+
width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 231 |
+
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 232 |
+
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
| 233 |
+
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 234 |
+
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 235 |
+
conv_state.scatter_(2, copy_idx, x)
|
| 236 |
+
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
|
| 237 |
+
if unsqueeze:
|
| 238 |
+
out = out.squeeze(-1)
|
| 239 |
+
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
causal-conv1d/causal_conv1d/causal_conv1d_varlen.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@triton.jit
|
| 9 |
+
def _causal_conv1d_varlen_states(
|
| 10 |
+
X,
|
| 11 |
+
CU_SEQLENS,
|
| 12 |
+
STATES,
|
| 13 |
+
state_len,
|
| 14 |
+
dim,
|
| 15 |
+
stride_x_seqlen, stride_x_dim,
|
| 16 |
+
stride_states_batch, stride_states_seqlen, stride_states_dim,
|
| 17 |
+
BLOCK_M: tl.constexpr,
|
| 18 |
+
BLOCK_N: tl.constexpr
|
| 19 |
+
):
|
| 20 |
+
batch_idx = tl.program_id(2)
|
| 21 |
+
STATES += batch_idx * stride_states_batch
|
| 22 |
+
end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
|
| 23 |
+
start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
|
| 24 |
+
rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 25 |
+
cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 26 |
+
x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
|
| 27 |
+
mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
|
| 28 |
+
other=0)
|
| 29 |
+
rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 30 |
+
tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
|
| 31 |
+
x,
|
| 32 |
+
mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 36 |
+
"""
|
| 37 |
+
Forward pass only, does not support backward pass.
|
| 38 |
+
Parameters:
|
| 39 |
+
x: (total_tokens, dim)
|
| 40 |
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 41 |
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 42 |
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 43 |
+
Return:
|
| 44 |
+
states: (batch, dim, state_len)
|
| 45 |
+
"""
|
| 46 |
+
_, dim = x.shape
|
| 47 |
+
batch = cu_seqlens.shape[0] - 1
|
| 48 |
+
cu_seqlens = cu_seqlens.contiguous()
|
| 49 |
+
states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 50 |
+
BLOCK_M = min(triton.next_power_of_2(state_len), 16)
|
| 51 |
+
BLOCK_N = min(triton.next_power_of_2(dim), 256)
|
| 52 |
+
grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
|
| 53 |
+
with torch.cuda.device(x.device.index):
|
| 54 |
+
_causal_conv1d_varlen_states[grid](
|
| 55 |
+
x,
|
| 56 |
+
cu_seqlens,
|
| 57 |
+
states,
|
| 58 |
+
state_len,
|
| 59 |
+
dim,
|
| 60 |
+
x.stride(0), x.stride(1),
|
| 61 |
+
states.stride(0), states.stride(2), states.stride(1),
|
| 62 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
| 63 |
+
)
|
| 64 |
+
return states
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 68 |
+
"""
|
| 69 |
+
Forward pass only, does not support backward pass.
|
| 70 |
+
Parameters:
|
| 71 |
+
x: (total_tokens, dim)
|
| 72 |
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 73 |
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 74 |
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 75 |
+
Return:
|
| 76 |
+
states: (batch, dim, state_len)
|
| 77 |
+
"""
|
| 78 |
+
_, dim = x.shape
|
| 79 |
+
batch = cu_seqlens.shape[0] - 1
|
| 80 |
+
cu_seqlens = cu_seqlens.contiguous()
|
| 81 |
+
states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 82 |
+
for i in range(batch):
|
| 83 |
+
end_idx = cu_seqlens[i + 1]
|
| 84 |
+
start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
|
| 85 |
+
states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
|
| 86 |
+
return states
|
causal-conv1d/csrc/causal_conv1d.cpp
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 6 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#include "causal_conv1d.h"
|
| 11 |
+
|
| 12 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 13 |
+
|
| 14 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
| 15 |
+
if (ITYPE == at::ScalarType::Half) { \
|
| 16 |
+
using input_t = at::Half; \
|
| 17 |
+
__VA_ARGS__(); \
|
| 18 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
| 19 |
+
using input_t = at::BFloat16; \
|
| 20 |
+
__VA_ARGS__(); \
|
| 21 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
| 22 |
+
using input_t = float; \
|
| 23 |
+
__VA_ARGS__(); \
|
| 24 |
+
} else { \
|
| 25 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
| 29 |
+
if (WTYPE == at::ScalarType::Half) { \
|
| 30 |
+
using weight_t = at::Half; \
|
| 31 |
+
__VA_ARGS__(); \
|
| 32 |
+
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
| 33 |
+
using weight_t = at::BFloat16; \
|
| 34 |
+
__VA_ARGS__(); \
|
| 35 |
+
} else if (WTYPE == at::ScalarType::Float) { \
|
| 36 |
+
using weight_t = float; \
|
| 37 |
+
__VA_ARGS__(); \
|
| 38 |
+
} else { \
|
| 39 |
+
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
template<typename input_t, typename weight_t>
|
| 43 |
+
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 44 |
+
template <typename input_t, typename weight_t>
|
| 45 |
+
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 46 |
+
|
| 47 |
+
template<typename input_t, typename weight_t>
|
| 48 |
+
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 49 |
+
template<typename input_t, typename weight_t>
|
| 50 |
+
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 51 |
+
|
| 52 |
+
template<typename input_t, typename weight_t>
|
| 53 |
+
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 54 |
+
|
| 55 |
+
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
| 56 |
+
// sizes
|
| 57 |
+
const size_t batch,
|
| 58 |
+
const size_t dim,
|
| 59 |
+
const size_t seqlen,
|
| 60 |
+
const size_t width,
|
| 61 |
+
// device pointers
|
| 62 |
+
const at::Tensor x,
|
| 63 |
+
const at::Tensor weight,
|
| 64 |
+
const at::Tensor out,
|
| 65 |
+
void* bias_ptr,
|
| 66 |
+
bool silu_activation) {
|
| 67 |
+
|
| 68 |
+
// Reset the parameters
|
| 69 |
+
memset(¶ms, 0, sizeof(params));
|
| 70 |
+
|
| 71 |
+
params.batch = batch;
|
| 72 |
+
params.dim = dim;
|
| 73 |
+
params.seqlen = seqlen;
|
| 74 |
+
params.width = width;
|
| 75 |
+
|
| 76 |
+
params.silu_activation = silu_activation;
|
| 77 |
+
|
| 78 |
+
// Set the pointers and strides.
|
| 79 |
+
params.x_ptr = x.data_ptr();
|
| 80 |
+
params.weight_ptr = weight.data_ptr();
|
| 81 |
+
params.bias_ptr = bias_ptr;
|
| 82 |
+
params.out_ptr = out.data_ptr();
|
| 83 |
+
// All stride are in elements, not bytes.
|
| 84 |
+
params.x_batch_stride = x.stride(0);
|
| 85 |
+
params.x_c_stride = x.stride(1);
|
| 86 |
+
params.x_l_stride = x.stride(-1);
|
| 87 |
+
params.weight_c_stride = weight.stride(0);
|
| 88 |
+
params.weight_width_stride = weight.stride(1);
|
| 89 |
+
params.out_batch_stride = out.stride(0);
|
| 90 |
+
params.out_c_stride = out.stride(1);
|
| 91 |
+
params.out_l_stride = out.stride(-1);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
void set_conv_params_bwd(ConvParamsBwd ¶ms,
|
| 96 |
+
// sizes
|
| 97 |
+
const size_t batch,
|
| 98 |
+
const size_t dim,
|
| 99 |
+
const size_t seqlen,
|
| 100 |
+
const size_t width,
|
| 101 |
+
// device pointers
|
| 102 |
+
const at::Tensor x,
|
| 103 |
+
const at::Tensor weight,
|
| 104 |
+
void* bias_ptr,
|
| 105 |
+
const at::Tensor dout,
|
| 106 |
+
const at::Tensor dx,
|
| 107 |
+
const at::Tensor dweight,
|
| 108 |
+
void* dbias_ptr,
|
| 109 |
+
bool silu_activation) {
|
| 110 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" at all.
|
| 111 |
+
set_conv_params_fwd(params, batch, dim, seqlen, width,
|
| 112 |
+
x, weight, dout, bias_ptr, silu_activation);
|
| 113 |
+
|
| 114 |
+
// Set the pointers and strides.
|
| 115 |
+
params.dout_ptr = dout.data_ptr();
|
| 116 |
+
params.dx_ptr = dx.data_ptr();
|
| 117 |
+
params.dweight_ptr = dweight.data_ptr();
|
| 118 |
+
params.dbias_ptr = dbias_ptr;
|
| 119 |
+
// All stride are in elements, not bytes.
|
| 120 |
+
params.dout_batch_stride = dout.stride(0);
|
| 121 |
+
params.dout_c_stride = dout.stride(1);
|
| 122 |
+
params.dout_l_stride = dout.stride(2);
|
| 123 |
+
params.dweight_c_stride = dweight.stride(0);
|
| 124 |
+
params.dweight_width_stride = dweight.stride(1);
|
| 125 |
+
params.dx_batch_stride = dx.stride(0);
|
| 126 |
+
params.dx_c_stride = dx.stride(1);
|
| 127 |
+
params.dx_l_stride = dx.stride(2);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
at::Tensor
|
| 131 |
+
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
| 132 |
+
const c10::optional<at::Tensor> &bias_,
|
| 133 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
| 134 |
+
const c10::optional<at::Tensor> &initial_states_,
|
| 135 |
+
c10::optional<at::Tensor> &final_states_out_,
|
| 136 |
+
bool silu_activation) {
|
| 137 |
+
auto input_type = x.scalar_type();
|
| 138 |
+
auto weight_type = weight.scalar_type();
|
| 139 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 140 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
| 141 |
+
|
| 142 |
+
TORCH_CHECK(x.is_cuda());
|
| 143 |
+
TORCH_CHECK(weight.is_cuda());
|
| 144 |
+
|
| 145 |
+
const auto sizes = x.sizes();
|
| 146 |
+
const int batch_size = sizes[0];
|
| 147 |
+
const int dim = sizes[1];
|
| 148 |
+
const int seqlen = sizes[2];
|
| 149 |
+
const int width = weight.size(-1);
|
| 150 |
+
|
| 151 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
| 152 |
+
CHECK_SHAPE(weight, dim, width);
|
| 153 |
+
|
| 154 |
+
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
| 155 |
+
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
| 156 |
+
|
| 157 |
+
if (is_channel_last) {
|
| 158 |
+
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
| 159 |
+
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
| 160 |
+
}
|
| 161 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
| 162 |
+
|
| 163 |
+
if (bias_.has_value()) {
|
| 164 |
+
auto bias = bias_.value();
|
| 165 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
| 166 |
+
TORCH_CHECK(bias.is_cuda());
|
| 167 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
| 168 |
+
CHECK_SHAPE(bias, dim);
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
if (seq_idx_.has_value()) {
|
| 172 |
+
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
| 173 |
+
auto seq_idx = seq_idx_.value();
|
| 174 |
+
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
| 175 |
+
TORCH_CHECK(seq_idx.is_cuda());
|
| 176 |
+
TORCH_CHECK(seq_idx.is_contiguous());
|
| 177 |
+
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
at::Tensor out = torch::empty_like(x);
|
| 181 |
+
|
| 182 |
+
ConvParamsBase params;
|
| 183 |
+
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
| 184 |
+
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
| 185 |
+
silu_activation);
|
| 186 |
+
|
| 187 |
+
if (seq_idx_.has_value()) {
|
| 188 |
+
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
| 189 |
+
} else {
|
| 190 |
+
params.seq_idx_ptr = nullptr;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
if (initial_states_.has_value()) {
|
| 194 |
+
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
| 195 |
+
auto initial_states = initial_states_.value();
|
| 196 |
+
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
| 197 |
+
TORCH_CHECK(initial_states.is_cuda());
|
| 198 |
+
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
| 199 |
+
TORCH_CHECK(initial_states.stride(1) == 1);
|
| 200 |
+
params.initial_states_ptr = initial_states.data_ptr();
|
| 201 |
+
params.initial_states_batch_stride = initial_states.stride(0);
|
| 202 |
+
params.initial_states_c_stride = initial_states.stride(1);
|
| 203 |
+
params.initial_states_l_stride = initial_states.stride(2);
|
| 204 |
+
} else {
|
| 205 |
+
params.initial_states_ptr = nullptr;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
if (final_states_out_.has_value()) {
|
| 209 |
+
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
| 210 |
+
auto final_states = final_states_out_.value();
|
| 211 |
+
TORCH_CHECK(final_states.scalar_type() == input_type);
|
| 212 |
+
TORCH_CHECK(final_states.is_cuda());
|
| 213 |
+
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
| 214 |
+
TORCH_CHECK(final_states.stride(1) == 1);
|
| 215 |
+
params.final_states_ptr = final_states.data_ptr();
|
| 216 |
+
params.final_states_batch_stride = final_states.stride(0);
|
| 217 |
+
params.final_states_c_stride = final_states.stride(1);
|
| 218 |
+
params.final_states_l_stride = final_states.stride(2);
|
| 219 |
+
} else {
|
| 220 |
+
params.final_states_ptr = nullptr;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 224 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 225 |
+
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
| 226 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 227 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
| 228 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
|
| 229 |
+
if (!is_channel_last) {
|
| 230 |
+
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
| 231 |
+
} else {
|
| 232 |
+
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
| 233 |
+
}
|
| 234 |
+
});
|
| 235 |
+
});
|
| 236 |
+
return out;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
std::vector<at::Tensor>
|
| 240 |
+
causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
|
| 241 |
+
const c10::optional<at::Tensor> &bias_,
|
| 242 |
+
at::Tensor &dout,
|
| 243 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
| 244 |
+
const c10::optional<at::Tensor> &initial_states_,
|
| 245 |
+
const c10::optional<at::Tensor> &dfinal_states_,
|
| 246 |
+
c10::optional<at::Tensor> &dx_,
|
| 247 |
+
bool return_dinitial_states,
|
| 248 |
+
bool silu_activation) {
|
| 249 |
+
auto input_type = x.scalar_type();
|
| 250 |
+
auto weight_type = weight.scalar_type();
|
| 251 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 252 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
| 253 |
+
|
| 254 |
+
TORCH_CHECK(x.is_cuda());
|
| 255 |
+
TORCH_CHECK(weight.is_cuda());
|
| 256 |
+
TORCH_CHECK(dout.is_cuda());
|
| 257 |
+
|
| 258 |
+
const auto sizes = x.sizes();
|
| 259 |
+
const int batch_size = sizes[0];
|
| 260 |
+
const int dim = sizes[1];
|
| 261 |
+
const int seqlen = sizes[2];
|
| 262 |
+
const int width = weight.size(-1);
|
| 263 |
+
|
| 264 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
| 265 |
+
|
| 266 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
| 267 |
+
CHECK_SHAPE(weight, dim, width);
|
| 268 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
| 269 |
+
|
| 270 |
+
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
| 271 |
+
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
| 272 |
+
if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
|
| 273 |
+
if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
|
| 274 |
+
|
| 275 |
+
if (is_channel_last) {
|
| 276 |
+
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
| 277 |
+
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
| 278 |
+
TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
if (bias_.has_value()) {
|
| 282 |
+
auto bias = bias_.value();
|
| 283 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
| 284 |
+
TORCH_CHECK(bias.is_cuda());
|
| 285 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
| 286 |
+
CHECK_SHAPE(bias, dim);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
if (seq_idx_.has_value()) {
|
| 290 |
+
TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
|
| 291 |
+
auto seq_idx = seq_idx_.value();
|
| 292 |
+
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
| 293 |
+
TORCH_CHECK(seq_idx.is_cuda());
|
| 294 |
+
TORCH_CHECK(seq_idx.is_contiguous());
|
| 295 |
+
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
at::Tensor dx;
|
| 299 |
+
if (dx_.has_value()) {
|
| 300 |
+
dx = dx_.value();
|
| 301 |
+
TORCH_CHECK(dx.scalar_type() == input_type);
|
| 302 |
+
TORCH_CHECK(dx.is_cuda());
|
| 303 |
+
CHECK_SHAPE(dx, batch_size, dim, seqlen);
|
| 304 |
+
if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
|
| 305 |
+
if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
|
| 306 |
+
} else {
|
| 307 |
+
dx = torch::empty_like(x);
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 311 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 312 |
+
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
| 313 |
+
|
| 314 |
+
at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
|
| 315 |
+
at::Tensor dbias;
|
| 316 |
+
if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
|
| 317 |
+
|
| 318 |
+
ConvParamsBwd params;
|
| 319 |
+
set_conv_params_bwd(params, batch_size, dim, seqlen, width,
|
| 320 |
+
x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
| 321 |
+
dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
|
| 322 |
+
silu_activation);
|
| 323 |
+
|
| 324 |
+
if (seq_idx_.has_value()) {
|
| 325 |
+
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
| 326 |
+
} else {
|
| 327 |
+
params.seq_idx_ptr = nullptr;
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
if (initial_states_.has_value()) {
|
| 331 |
+
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
| 332 |
+
auto initial_states = initial_states_.value();
|
| 333 |
+
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
| 334 |
+
TORCH_CHECK(initial_states.is_cuda());
|
| 335 |
+
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
| 336 |
+
TORCH_CHECK(initial_states.stride(1) == 1);
|
| 337 |
+
params.initial_states_ptr = initial_states.data_ptr();
|
| 338 |
+
params.initial_states_batch_stride = initial_states.stride(0);
|
| 339 |
+
params.initial_states_c_stride = initial_states.stride(1);
|
| 340 |
+
params.initial_states_l_stride = initial_states.stride(2);
|
| 341 |
+
} else {
|
| 342 |
+
params.initial_states_ptr = nullptr;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
if (dfinal_states_.has_value()) {
|
| 346 |
+
TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
|
| 347 |
+
auto dfinal_states = dfinal_states_.value();
|
| 348 |
+
TORCH_CHECK(dfinal_states.scalar_type() == input_type);
|
| 349 |
+
TORCH_CHECK(dfinal_states.is_cuda());
|
| 350 |
+
CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
|
| 351 |
+
params.dfinal_states_ptr = dfinal_states.data_ptr();
|
| 352 |
+
params.dfinal_states_batch_stride = dfinal_states.stride(0);
|
| 353 |
+
params.dfinal_states_c_stride = dfinal_states.stride(1);
|
| 354 |
+
params.dfinal_states_l_stride = dfinal_states.stride(2);
|
| 355 |
+
} else {
|
| 356 |
+
params.dfinal_states_ptr = nullptr;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
at::Tensor dinitial_states;
|
| 360 |
+
if (return_dinitial_states) {
|
| 361 |
+
dinitial_states = torch::empty({batch_size, width - 1, dim}, x.options()).transpose(1, 2);
|
| 362 |
+
TORCH_CHECK(dinitial_states.stride(1) == 1);
|
| 363 |
+
params.dinitial_states_ptr = dinitial_states.data_ptr();
|
| 364 |
+
params.dinitial_states_batch_stride = dinitial_states.stride(0);
|
| 365 |
+
params.dinitial_states_c_stride = dinitial_states.stride(1);
|
| 366 |
+
params.dinitial_states_l_stride = dinitial_states.stride(2);
|
| 367 |
+
} else {
|
| 368 |
+
params.dinitial_states_ptr = nullptr;
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 372 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
|
| 373 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
|
| 374 |
+
if (!is_channel_last) {
|
| 375 |
+
causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
|
| 376 |
+
} else {
|
| 377 |
+
causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
|
| 378 |
+
}
|
| 379 |
+
});
|
| 380 |
+
});
|
| 381 |
+
return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias, dinitial_states};
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
at::Tensor
|
| 385 |
+
causal_conv1d_update(const at::Tensor &x,
|
| 386 |
+
const at::Tensor &conv_state,
|
| 387 |
+
const at::Tensor &weight,
|
| 388 |
+
const c10::optional<at::Tensor> &bias_,
|
| 389 |
+
bool silu_activation,
|
| 390 |
+
const c10::optional<at::Tensor> &cache_seqlens_
|
| 391 |
+
) {
|
| 392 |
+
auto input_type = x.scalar_type();
|
| 393 |
+
auto weight_type = weight.scalar_type();
|
| 394 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 395 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
| 396 |
+
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
| 397 |
+
|
| 398 |
+
TORCH_CHECK(x.is_cuda());
|
| 399 |
+
TORCH_CHECK(conv_state.is_cuda());
|
| 400 |
+
TORCH_CHECK(weight.is_cuda());
|
| 401 |
+
|
| 402 |
+
const auto sizes = x.sizes();
|
| 403 |
+
const int batch_size = sizes[0];
|
| 404 |
+
const int dim = sizes[1];
|
| 405 |
+
const int seqlen = sizes[2];
|
| 406 |
+
const int width = weight.size(-1);
|
| 407 |
+
const int conv_state_len = conv_state.size(2);
|
| 408 |
+
TORCH_CHECK(conv_state_len >= width - 1);
|
| 409 |
+
|
| 410 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
| 411 |
+
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
| 412 |
+
CHECK_SHAPE(weight, dim, width);
|
| 413 |
+
|
| 414 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
| 415 |
+
|
| 416 |
+
if (bias_.has_value()) {
|
| 417 |
+
auto bias = bias_.value();
|
| 418 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
| 419 |
+
TORCH_CHECK(bias.is_cuda());
|
| 420 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
| 421 |
+
CHECK_SHAPE(bias, dim);
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
at::Tensor out = torch::empty_like(x);
|
| 425 |
+
|
| 426 |
+
ConvParamsBase params;
|
| 427 |
+
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
| 428 |
+
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
| 429 |
+
silu_activation);
|
| 430 |
+
params.conv_state_ptr = conv_state.data_ptr();
|
| 431 |
+
params.conv_state_len = conv_state_len;
|
| 432 |
+
// All stride are in elements, not bytes.
|
| 433 |
+
params.conv_state_batch_stride = conv_state.stride(0);
|
| 434 |
+
params.conv_state_c_stride = conv_state.stride(1);
|
| 435 |
+
params.conv_state_l_stride = conv_state.stride(2);
|
| 436 |
+
|
| 437 |
+
if (cache_seqlens_.has_value()) {
|
| 438 |
+
auto cache_seqlens = cache_seqlens_.value();
|
| 439 |
+
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
| 440 |
+
TORCH_CHECK(cache_seqlens.is_cuda());
|
| 441 |
+
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
| 442 |
+
CHECK_SHAPE(cache_seqlens, batch_size);
|
| 443 |
+
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
| 444 |
+
} else {
|
| 445 |
+
params.cache_seqlens = nullptr;
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 449 |
+
// Cast to char to avoid compiler warning about narrowing
|
| 450 |
+
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
| 451 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 452 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
| 453 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
|
| 454 |
+
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
| 455 |
+
});
|
| 456 |
+
});
|
| 457 |
+
return out;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 461 |
+
m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
|
| 462 |
+
m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
|
| 463 |
+
m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
|
| 464 |
+
}
|
causal-conv1d/csrc/causal_conv1d.h
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
|
| 9 |
+
struct ConvParamsBase {
|
| 10 |
+
using index_t = uint32_t;
|
| 11 |
+
|
| 12 |
+
int batch, dim, seqlen, width;
|
| 13 |
+
bool silu_activation;
|
| 14 |
+
|
| 15 |
+
index_t x_batch_stride;
|
| 16 |
+
index_t x_c_stride;
|
| 17 |
+
index_t x_l_stride;
|
| 18 |
+
index_t weight_c_stride;
|
| 19 |
+
index_t weight_width_stride;
|
| 20 |
+
index_t out_batch_stride;
|
| 21 |
+
index_t out_c_stride;
|
| 22 |
+
index_t out_l_stride;
|
| 23 |
+
|
| 24 |
+
int conv_state_len;
|
| 25 |
+
index_t conv_state_batch_stride;
|
| 26 |
+
index_t conv_state_c_stride;
|
| 27 |
+
index_t conv_state_l_stride;
|
| 28 |
+
|
| 29 |
+
// Common data pointers.
|
| 30 |
+
void *__restrict__ x_ptr;
|
| 31 |
+
void *__restrict__ weight_ptr;
|
| 32 |
+
void *__restrict__ bias_ptr;
|
| 33 |
+
void *__restrict__ out_ptr;
|
| 34 |
+
|
| 35 |
+
void *__restrict__ conv_state_ptr;
|
| 36 |
+
int32_t *__restrict__ cache_seqlens;
|
| 37 |
+
|
| 38 |
+
void *__restrict__ seq_idx_ptr;
|
| 39 |
+
|
| 40 |
+
// No __restrict__ since initial_states could be the same as final_states.
|
| 41 |
+
void * initial_states_ptr;
|
| 42 |
+
index_t initial_states_batch_stride;
|
| 43 |
+
index_t initial_states_l_stride;
|
| 44 |
+
index_t initial_states_c_stride;
|
| 45 |
+
|
| 46 |
+
void * final_states_ptr;
|
| 47 |
+
index_t final_states_batch_stride;
|
| 48 |
+
index_t final_states_l_stride;
|
| 49 |
+
index_t final_states_c_stride;
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
struct ConvParamsBwd: public ConvParamsBase {
|
| 53 |
+
index_t dx_batch_stride;
|
| 54 |
+
index_t dx_c_stride;
|
| 55 |
+
index_t dx_l_stride;
|
| 56 |
+
index_t dweight_c_stride;
|
| 57 |
+
index_t dweight_width_stride;
|
| 58 |
+
index_t dout_batch_stride;
|
| 59 |
+
index_t dout_c_stride;
|
| 60 |
+
index_t dout_l_stride;
|
| 61 |
+
|
| 62 |
+
// Common data pointers.
|
| 63 |
+
void *__restrict__ dx_ptr;
|
| 64 |
+
void *__restrict__ dweight_ptr;
|
| 65 |
+
void *__restrict__ dbias_ptr;
|
| 66 |
+
void *__restrict__ dout_ptr;
|
| 67 |
+
|
| 68 |
+
void * dinitial_states_ptr;
|
| 69 |
+
index_t dinitial_states_batch_stride;
|
| 70 |
+
index_t dinitial_states_l_stride;
|
| 71 |
+
index_t dinitial_states_c_stride;
|
| 72 |
+
|
| 73 |
+
void * dfinal_states_ptr;
|
| 74 |
+
index_t dfinal_states_batch_stride;
|
| 75 |
+
index_t dfinal_states_l_stride;
|
| 76 |
+
index_t dfinal_states_c_stride;
|
| 77 |
+
};
|
causal-conv1d/csrc/causal_conv1d_bwd.cu
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 8 |
+
|
| 9 |
+
#ifndef USE_ROCM
|
| 10 |
+
#include <cub/block/block_load.cuh>
|
| 11 |
+
#include <cub/block/block_store.cuh>
|
| 12 |
+
#include <cub/block/block_reduce.cuh>
|
| 13 |
+
#else
|
| 14 |
+
#include <hipcub/hipcub.hpp>
|
| 15 |
+
namespace cub = hipcub;
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#include "causal_conv1d.h"
|
| 19 |
+
#include "causal_conv1d_common.h"
|
| 20 |
+
#include "static_switch.h"
|
| 21 |
+
|
| 22 |
+
template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 23 |
+
struct Causal_conv1d_bwd_kernel_traits {
|
| 24 |
+
using input_t = input_t_;
|
| 25 |
+
using weight_t = weight_t_;
|
| 26 |
+
static constexpr int kNThreads = kNThreads_;
|
| 27 |
+
static constexpr int kWidth = kWidth_;
|
| 28 |
+
static constexpr bool kSiluAct = kSiluAct_;
|
| 29 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 30 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 31 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 32 |
+
static_assert(kWidth <= kNElts);
|
| 33 |
+
// It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
|
| 34 |
+
// (since then we'd have 8 values of float, and each round we can exchange 4 floats).
|
| 35 |
+
static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
|
| 36 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 37 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 38 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 39 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
| 40 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 41 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
| 42 |
+
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
| 43 |
+
static constexpr int kSmemIOSize = kIsVecLoad
|
| 44 |
+
? 0
|
| 45 |
+
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
| 46 |
+
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
|
| 47 |
+
static constexpr int kSmemSize = custom_max({kSmemExchangeSize,
|
| 48 |
+
int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
|
| 49 |
+
};
|
| 50 |
+
|
| 51 |
+
template<typename Ktraits>
|
| 52 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 53 |
+
void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
|
| 54 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 55 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 56 |
+
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
| 57 |
+
static constexpr int kNElts = Ktraits::kNElts;
|
| 58 |
+
constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
|
| 59 |
+
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
| 60 |
+
using input_t = typename Ktraits::input_t;
|
| 61 |
+
using vec_t = typename Ktraits::vec_t;
|
| 62 |
+
using weight_t = typename Ktraits::weight_t;
|
| 63 |
+
|
| 64 |
+
// Shared memory.
|
| 65 |
+
extern __shared__ char smem_[];
|
| 66 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 67 |
+
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
| 68 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 69 |
+
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
| 70 |
+
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
| 71 |
+
vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
|
| 72 |
+
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 73 |
+
|
| 74 |
+
const int tidx = threadIdx.x;
|
| 75 |
+
const int batch_id = blockIdx.x;
|
| 76 |
+
const int dim_id = blockIdx.y;
|
| 77 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 78 |
+
+ dim_id * params.x_c_stride;
|
| 79 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
|
| 80 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
| 81 |
+
+ dim_id * params.dout_c_stride;
|
| 82 |
+
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
| 83 |
+
+ dim_id * params.dx_c_stride;
|
| 84 |
+
float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
|
| 85 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
|
| 86 |
+
|
| 87 |
+
// Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
|
| 88 |
+
if (tidx == 0) {
|
| 89 |
+
if constexpr (!kSiluAct) {
|
| 90 |
+
input_t zeros[kNElts] = {0};
|
| 91 |
+
smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
|
| 92 |
+
} else {
|
| 93 |
+
float zeros[kNElts] = {0};
|
| 94 |
+
#pragma unroll
|
| 95 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 96 |
+
smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
float weight_vals[kWidth];
|
| 102 |
+
#pragma unroll
|
| 103 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
|
| 104 |
+
|
| 105 |
+
float dweight_vals[kWidth] = {0};
|
| 106 |
+
float dbias_val = 0;
|
| 107 |
+
|
| 108 |
+
constexpr int kChunkSize = kNThreads * kNElts;
|
| 109 |
+
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
| 110 |
+
x += (n_chunks - 1) * kChunkSize;
|
| 111 |
+
dout += (n_chunks - 1) * kChunkSize;
|
| 112 |
+
dx += (n_chunks - 1) * kChunkSize;
|
| 113 |
+
for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
|
| 114 |
+
input_t x_vals_load[2 * kNElts] = {0};
|
| 115 |
+
input_t dout_vals_load[2 * kNElts] = {0};
|
| 116 |
+
if constexpr(kIsVecLoad) {
|
| 117 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 118 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 119 |
+
} else {
|
| 120 |
+
__syncthreads();
|
| 121 |
+
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
| 122 |
+
__syncthreads();
|
| 123 |
+
typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
|
| 124 |
+
}
|
| 125 |
+
float dout_vals[2 * kNElts], x_vals[2 * kNElts];
|
| 126 |
+
if constexpr (!kSiluAct) {
|
| 127 |
+
__syncthreads();
|
| 128 |
+
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
|
| 129 |
+
// the first elements of the next chunk.
|
| 130 |
+
if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
|
| 131 |
+
__syncthreads();
|
| 132 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
|
| 133 |
+
__syncthreads();
|
| 134 |
+
// Now thread 0 can write the first elements of the current chunk.
|
| 135 |
+
if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
|
| 136 |
+
#pragma unroll
|
| 137 |
+
for (int i = 0; i < 2 * kNElts; ++i) {
|
| 138 |
+
dout_vals[i] = float(dout_vals_load[i]);
|
| 139 |
+
x_vals[i] = float(x_vals_load[i]);
|
| 140 |
+
}
|
| 141 |
+
} else {
|
| 142 |
+
if (tidx == 0 && chunk > 0) {
|
| 143 |
+
if constexpr(kIsVecLoad) {
|
| 144 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
|
| 145 |
+
} else {
|
| 146 |
+
#pragma unroll
|
| 147 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 148 |
+
if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
__syncthreads();
|
| 153 |
+
smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
|
| 154 |
+
__syncthreads();
|
| 155 |
+
if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
|
| 156 |
+
#pragma unroll
|
| 157 |
+
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
| 158 |
+
// Recompute the output
|
| 159 |
+
#pragma unroll
|
| 160 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 161 |
+
float out_val = bias_val;
|
| 162 |
+
#pragma unroll
|
| 163 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 164 |
+
out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
| 165 |
+
}
|
| 166 |
+
float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
|
| 167 |
+
dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
|
| 168 |
+
* (1.0f + out_val * (1.0f - out_sigmoid_val));
|
| 169 |
+
}
|
| 170 |
+
// Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
|
| 171 |
+
// if input_t is 16 bits (since then we'd have 8 values of float)
|
| 172 |
+
__syncthreads();
|
| 173 |
+
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
|
| 174 |
+
// the first elements of the next chunk.
|
| 175 |
+
if (tidx > 0) {
|
| 176 |
+
#pragma unroll
|
| 177 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 178 |
+
smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
__syncthreads();
|
| 182 |
+
#pragma unroll
|
| 183 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 184 |
+
reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
|
| 185 |
+
= smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
|
| 186 |
+
}
|
| 187 |
+
__syncthreads();
|
| 188 |
+
// Now thread 0 can write the first elements of the current chunk.
|
| 189 |
+
if (tidx == 0) {
|
| 190 |
+
#pragma unroll
|
| 191 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 192 |
+
smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
dout -= kChunkSize;
|
| 197 |
+
x -= kChunkSize;
|
| 198 |
+
|
| 199 |
+
#pragma unroll
|
| 200 |
+
for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
|
| 201 |
+
|
| 202 |
+
float dx_vals[kNElts] = {0};
|
| 203 |
+
#pragma unroll
|
| 204 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 205 |
+
#pragma unroll
|
| 206 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 207 |
+
dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
input_t dx_vals_store[kNElts];
|
| 212 |
+
#pragma unroll
|
| 213 |
+
for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
|
| 214 |
+
if constexpr(kIsVecLoad) {
|
| 215 |
+
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 216 |
+
} else {
|
| 217 |
+
typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
|
| 218 |
+
}
|
| 219 |
+
dx -= kChunkSize;
|
| 220 |
+
|
| 221 |
+
#pragma unroll
|
| 222 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 223 |
+
#pragma unroll
|
| 224 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 225 |
+
dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
#pragma unroll
|
| 231 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 232 |
+
__syncthreads();
|
| 233 |
+
dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
|
| 234 |
+
if (tidx == 0) {
|
| 235 |
+
atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
if (params.bias_ptr != nullptr) {
|
| 239 |
+
__syncthreads();
|
| 240 |
+
dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
|
| 241 |
+
if (tidx == 0) {
|
| 242 |
+
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 248 |
+
void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 249 |
+
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
| 250 |
+
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
| 251 |
+
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
| 252 |
+
using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
|
| 253 |
+
constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 254 |
+
dim3 grid(params.batch, params.dim);
|
| 255 |
+
auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
|
| 256 |
+
|
| 257 |
+
if (kSmemSize >= 48 * 1024) {
|
| 258 |
+
#ifndef USE_ROCM
|
| 259 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 260 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 261 |
+
#else
|
| 262 |
+
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
| 263 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 264 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 265 |
+
std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
| 266 |
+
#endif
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 271 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 272 |
+
});
|
| 273 |
+
});
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
template<typename input_t, typename weight_t>
|
| 277 |
+
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 278 |
+
if (params.width == 2) {
|
| 279 |
+
causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 280 |
+
} else if (params.width == 3) {
|
| 281 |
+
causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 282 |
+
} else if (params.width == 4) {
|
| 283 |
+
causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 288 |
+
struct Causal_conv1d_channellast_bwd_kernel_traits {
|
| 289 |
+
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
| 290 |
+
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
| 291 |
+
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
| 292 |
+
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
| 293 |
+
using input_t = input_t_;
|
| 294 |
+
using weight_t = weight_t_;
|
| 295 |
+
static constexpr bool kSiluAct = kSiluAct_;
|
| 296 |
+
static constexpr int kNThreads = kNThreads_;
|
| 297 |
+
static_assert(kNThreads % 32 == 0);
|
| 298 |
+
static constexpr int kNWarps = kNThreads / 32;
|
| 299 |
+
static constexpr int kWidth = kWidth_;
|
| 300 |
+
static constexpr int kChunkSizeL = kChunkSizeL_;
|
| 301 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 302 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 303 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 304 |
+
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
| 305 |
+
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
| 306 |
+
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
| 307 |
+
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
| 308 |
+
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
| 309 |
+
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
| 310 |
+
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
| 311 |
+
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
| 312 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 313 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 314 |
+
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 315 |
+
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 316 |
+
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 317 |
+
// sizeof(typename BlockStoreT::TempStorage)});
|
| 318 |
+
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
| 319 |
+
};
|
| 320 |
+
|
| 321 |
+
template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
|
| 322 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 323 |
+
void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
|
| 324 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 325 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 326 |
+
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
| 327 |
+
constexpr int kNElts = Ktraits::kNElts;
|
| 328 |
+
constexpr int kNWarp = Ktraits::kNWarps;
|
| 329 |
+
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
| 330 |
+
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
| 331 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 332 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 333 |
+
using input_t = typename Ktraits::input_t;
|
| 334 |
+
using vec_t = typename Ktraits::vec_t;
|
| 335 |
+
using weight_t = typename Ktraits::weight_t;
|
| 336 |
+
|
| 337 |
+
// Shared memory.
|
| 338 |
+
__shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
| 339 |
+
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
| 340 |
+
|
| 341 |
+
const int batch_id = blockIdx.x;
|
| 342 |
+
const int chunk_l_id = blockIdx.y;
|
| 343 |
+
const int chunk_c_id = blockIdx.z;
|
| 344 |
+
const int tid = threadIdx.x;
|
| 345 |
+
const int l_idx = tid / kNThreadsPerC;
|
| 346 |
+
const int c_idx = tid % kNThreadsPerC;
|
| 347 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 348 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 349 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
| 350 |
+
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
| 351 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
| 352 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 353 |
+
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
| 354 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 355 |
+
float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
|
| 356 |
+
+ chunk_c_id * kChunkSizeC * params.dweight_c_stride;
|
| 357 |
+
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
| 358 |
+
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
| 359 |
+
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
| 360 |
+
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 361 |
+
input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
| 362 |
+
: reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 363 |
+
input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
|
| 364 |
+
: reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
|
| 365 |
+
|
| 366 |
+
#pragma unroll
|
| 367 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 368 |
+
input_t dout_vals_load[kNElts] = {0};
|
| 369 |
+
input_t x_vals_load[kNElts] = {0};
|
| 370 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 371 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 372 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
|
| 373 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
| 374 |
+
}
|
| 375 |
+
reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
| 376 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 377 |
+
}
|
| 378 |
+
// Load the elements from the previous chunk or next chunk that are needed for convolution.
|
| 379 |
+
if (l_idx < kWidth - 1) {
|
| 380 |
+
input_t dout_vals_load[kNElts] = {0};
|
| 381 |
+
input_t x_vals_load[kNElts] = {0};
|
| 382 |
+
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
| 383 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 384 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
|
| 385 |
+
}
|
| 386 |
+
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
| 387 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
| 388 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 389 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
| 390 |
+
} else if (initial_states != nullptr
|
| 391 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
| 392 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 393 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
| 394 |
+
}
|
| 395 |
+
reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
| 396 |
+
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 397 |
+
}
|
| 398 |
+
// Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
|
| 399 |
+
if constexpr (kSiluAct) {
|
| 400 |
+
if (l_idx < kWidth - 1) {
|
| 401 |
+
input_t x_vals_load[kNElts] = {0};
|
| 402 |
+
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
| 403 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 404 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
|
| 405 |
+
}
|
| 406 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
__syncthreads();
|
| 411 |
+
|
| 412 |
+
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
| 413 |
+
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
| 414 |
+
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
| 415 |
+
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
| 416 |
+
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
| 417 |
+
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
| 418 |
+
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
| 419 |
+
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
| 420 |
+
static_assert(kNThreadsPerRow <= 32);
|
| 421 |
+
|
| 422 |
+
const int row_idx = tid / kNThreadsPerRow;
|
| 423 |
+
const int col_idx = tid % kNThreadsPerRow;
|
| 424 |
+
|
| 425 |
+
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
| 426 |
+
float weight_vals[kWidth] = {0};
|
| 427 |
+
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 428 |
+
#pragma unroll
|
| 429 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 430 |
+
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
float dout_vals[kLPerThread + kWidth - 1];
|
| 434 |
+
float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
|
| 435 |
+
#pragma unroll
|
| 436 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
| 437 |
+
dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
|
| 438 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
|
| 442 |
+
if constexpr (kHasSeqIdx) {
|
| 443 |
+
#pragma unroll
|
| 444 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
| 445 |
+
const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
|
| 446 |
+
seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
if constexpr (kSiluAct) { // Recompute the output
|
| 451 |
+
#pragma unroll
|
| 452 |
+
for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
| 453 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
| 454 |
+
}
|
| 455 |
+
#pragma unroll
|
| 456 |
+
for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
|
| 457 |
+
float out_val = bias_val;
|
| 458 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 459 |
+
#pragma unroll
|
| 460 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 461 |
+
if constexpr (!kHasSeqIdx) {
|
| 462 |
+
out_val += weight_vals[w] * x_vals[i + w];
|
| 463 |
+
} else {
|
| 464 |
+
out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
|
| 468 |
+
dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
float dweight_vals[kWidth] = {0};
|
| 473 |
+
SumOp<float> sum_op;
|
| 474 |
+
#pragma unroll
|
| 475 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 476 |
+
#pragma unroll
|
| 477 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
| 478 |
+
if constexpr (!kHasSeqIdx) {
|
| 479 |
+
dweight_vals[w] += x_vals[i + w] * dout_vals[i];
|
| 480 |
+
} else {
|
| 481 |
+
dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
|
| 485 |
+
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 486 |
+
atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
if (params.bias_ptr != nullptr) {
|
| 491 |
+
float dbias_val = 0.f;
|
| 492 |
+
for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
|
| 493 |
+
dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
|
| 494 |
+
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 495 |
+
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
float dx_vals[kLPerThread] = {0};
|
| 500 |
+
#pragma unroll
|
| 501 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
| 502 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 503 |
+
#pragma unroll
|
| 504 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 505 |
+
if constexpr (!kHasSeqIdx) {
|
| 506 |
+
dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
|
| 507 |
+
} else {
|
| 508 |
+
dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
// if (dfinal_states != nullptr) {
|
| 512 |
+
if constexpr (kHasDfinalStates) {
|
| 513 |
+
if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
|
| 514 |
+
&& chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
|
| 515 |
+
&& chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 516 |
+
dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
| 517 |
+
}
|
| 518 |
+
}
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
float dxinit_vals[kWidth - 1] = {0};
|
| 522 |
+
static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
|
| 523 |
+
if (dinitial_states != nullptr && col_idx == 0) {
|
| 524 |
+
#pragma unroll
|
| 525 |
+
for (int i = 0; i < kWidth - 1; ++i) {
|
| 526 |
+
#pragma unroll
|
| 527 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 528 |
+
dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
|
| 529 |
+
}
|
| 530 |
+
// chunk_l_id must be 0 because dinitial_states != nullptr
|
| 531 |
+
// if (dfinal_states != nullptr) {
|
| 532 |
+
if constexpr (kHasDfinalStates) {
|
| 533 |
+
if (i >= params.seqlen) {
|
| 534 |
+
dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
__syncthreads();
|
| 541 |
+
#pragma unroll
|
| 542 |
+
for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
|
| 543 |
+
if (dinitial_states != nullptr && col_idx == 0) {
|
| 544 |
+
#pragma unroll
|
| 545 |
+
for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
|
| 546 |
+
}
|
| 547 |
+
__syncthreads();
|
| 548 |
+
|
| 549 |
+
#pragma unroll
|
| 550 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 551 |
+
input_t dx_vals_store[kNElts];
|
| 552 |
+
reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
|
| 553 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 554 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 555 |
+
*reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
|
| 556 |
+
}
|
| 557 |
+
}
|
| 558 |
+
if (dinitial_states != nullptr
|
| 559 |
+
&& l_idx < kWidth - 1
|
| 560 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 561 |
+
input_t dxinit_vals_store[kNElts];
|
| 562 |
+
reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
|
| 563 |
+
*reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 569 |
+
void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 570 |
+
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
| 571 |
+
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
| 572 |
+
BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
|
| 573 |
+
BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
|
| 574 |
+
// kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
|
| 575 |
+
static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
|
| 576 |
+
using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
|
| 577 |
+
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 578 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 579 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 580 |
+
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
| 581 |
+
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
| 582 |
+
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
| 583 |
+
dim3 block(Ktraits::kNThreads);
|
| 584 |
+
auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
|
| 585 |
+
// if (kSmemSize >= 48 * 1024) {
|
| 586 |
+
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 587 |
+
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 588 |
+
// }
|
| 589 |
+
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 590 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 591 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 592 |
+
});
|
| 593 |
+
});
|
| 594 |
+
});
|
| 595 |
+
});
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
template<typename input_t, typename weight_t>
|
| 599 |
+
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 600 |
+
if (params.width == 2) {
|
| 601 |
+
causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 602 |
+
} else if (params.width == 3) {
|
| 603 |
+
causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 604 |
+
} else if (params.width == 4) {
|
| 605 |
+
causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 606 |
+
}
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 610 |
+
template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 611 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 612 |
+
template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 613 |
+
template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 614 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 615 |
+
template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 616 |
+
template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 617 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 618 |
+
|
| 619 |
+
template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 620 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 621 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 622 |
+
template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 623 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 624 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 625 |
+
template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 626 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 627 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
causal-conv1d/csrc/causal_conv1d_common.h
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#ifndef USE_ROCM
|
| 8 |
+
#include <cuda_bf16.h>
|
| 9 |
+
|
| 10 |
+
template<typename T>
|
| 11 |
+
__device__ inline T shuffle_xor(T val, int offset) {
|
| 12 |
+
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
| 16 |
+
{
|
| 17 |
+
return std::max(ilist);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
template<typename T>
|
| 21 |
+
constexpr T constexpr_min(T a, T b) {
|
| 22 |
+
return std::min(a, b);
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
#else
|
| 26 |
+
#include <hip/hip_bf16.h>
|
| 27 |
+
|
| 28 |
+
template<typename T>
|
| 29 |
+
__device__ inline T shuffle_xor(T val, int offset) {
|
| 30 |
+
return __shfl_xor(val, offset);
|
| 31 |
+
}
|
| 32 |
+
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
| 33 |
+
{
|
| 34 |
+
return *std::max_element(ilist.begin(), ilist.end());
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template<typename T>
|
| 38 |
+
constexpr T constexpr_min(T a, T b) {
|
| 39 |
+
return a < b ? a : b;
|
| 40 |
+
}
|
| 41 |
+
#endif
|
| 42 |
+
#include <cuda_fp16.h>
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
template<int BYTES> struct BytesToType {};
|
| 47 |
+
|
| 48 |
+
template<> struct BytesToType<16> {
|
| 49 |
+
using Type = uint4;
|
| 50 |
+
static_assert(sizeof(Type) == 16);
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template<> struct BytesToType<8> {
|
| 54 |
+
using Type = uint64_t;
|
| 55 |
+
static_assert(sizeof(Type) == 8);
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template<> struct BytesToType<4> {
|
| 59 |
+
using Type = uint32_t;
|
| 60 |
+
static_assert(sizeof(Type) == 4);
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
template<> struct BytesToType<2> {
|
| 64 |
+
using Type = uint16_t;
|
| 65 |
+
static_assert(sizeof(Type) == 2);
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
template<> struct BytesToType<1> {
|
| 69 |
+
using Type = uint8_t;
|
| 70 |
+
static_assert(sizeof(Type) == 1);
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 74 |
+
|
| 75 |
+
template<typename T>
|
| 76 |
+
struct SumOp {
|
| 77 |
+
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
template<int THREADS>
|
| 81 |
+
struct Allreduce {
|
| 82 |
+
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
| 83 |
+
template<typename T, typename Operator>
|
| 84 |
+
static __device__ inline T run(T x, Operator &op) {
|
| 85 |
+
constexpr int OFFSET = THREADS / 2;
|
| 86 |
+
x = op(x, shuffle_xor(x, OFFSET));
|
| 87 |
+
return Allreduce<OFFSET>::run(x, op);
|
| 88 |
+
}
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
template<>
|
| 92 |
+
struct Allreduce<2> {
|
| 93 |
+
template<typename T, typename Operator>
|
| 94 |
+
static __device__ inline T run(T x, Operator &op) {
|
| 95 |
+
x = op(x, shuffle_xor(x, 1));
|
| 96 |
+
return x;
|
| 97 |
+
}
|
| 98 |
+
};
|
causal-conv1d/csrc/causal_conv1d_fwd.cu
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 8 |
+
|
| 9 |
+
#ifndef USE_ROCM
|
| 10 |
+
#include <cub/block/block_load.cuh>
|
| 11 |
+
#include <cub/block/block_store.cuh>
|
| 12 |
+
#else
|
| 13 |
+
#include <hipcub/hipcub.hpp>
|
| 14 |
+
namespace cub = hipcub;
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include "causal_conv1d.h"
|
| 18 |
+
#include "causal_conv1d_common.h"
|
| 19 |
+
#include "static_switch.h"
|
| 20 |
+
|
| 21 |
+
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 22 |
+
struct Causal_conv1d_fwd_kernel_traits {
|
| 23 |
+
using input_t = input_t_;
|
| 24 |
+
using weight_t = weight_t_;
|
| 25 |
+
static constexpr int kNThreads = kNThreads_;
|
| 26 |
+
static constexpr int kWidth = kWidth_;
|
| 27 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 28 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 29 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 30 |
+
static_assert(kWidth <= kNElts);
|
| 31 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 32 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 33 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 34 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
| 35 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 36 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
| 37 |
+
static constexpr int kSmemIOSize = kIsVecLoad
|
| 38 |
+
? 0
|
| 39 |
+
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
| 40 |
+
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
| 41 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
template<typename Ktraits>
|
| 45 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 46 |
+
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
| 47 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 48 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 49 |
+
constexpr int kNElts = Ktraits::kNElts;
|
| 50 |
+
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
| 51 |
+
using input_t = typename Ktraits::input_t;
|
| 52 |
+
using vec_t = typename Ktraits::vec_t;
|
| 53 |
+
using weight_t = typename Ktraits::weight_t;
|
| 54 |
+
|
| 55 |
+
// Shared memory.
|
| 56 |
+
extern __shared__ char smem_[];
|
| 57 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 58 |
+
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
| 59 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 60 |
+
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
| 61 |
+
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
| 62 |
+
|
| 63 |
+
const int tidx = threadIdx.x;
|
| 64 |
+
const int batch_id = blockIdx.x;
|
| 65 |
+
const int channel_id = blockIdx.y;
|
| 66 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 67 |
+
+ channel_id * params.x_c_stride;
|
| 68 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
| 69 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 70 |
+
+ channel_id * params.out_c_stride;
|
| 71 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
| 72 |
+
|
| 73 |
+
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
| 74 |
+
if (tidx == 0) {
|
| 75 |
+
input_t zeros[kNElts] = {0};
|
| 76 |
+
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
float weight_vals[kWidth];
|
| 80 |
+
#pragma unroll
|
| 81 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
| 82 |
+
|
| 83 |
+
constexpr int kChunkSize = kNThreads * kNElts;
|
| 84 |
+
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
| 85 |
+
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
| 86 |
+
input_t x_vals_load[2 * kNElts] = {0};
|
| 87 |
+
if constexpr(kIsVecLoad) {
|
| 88 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 89 |
+
} else {
|
| 90 |
+
__syncthreads();
|
| 91 |
+
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
| 92 |
+
}
|
| 93 |
+
x += kChunkSize;
|
| 94 |
+
__syncthreads();
|
| 95 |
+
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
| 96 |
+
// the last elements of the previous chunk.
|
| 97 |
+
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
| 98 |
+
__syncthreads();
|
| 99 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
| 100 |
+
__syncthreads();
|
| 101 |
+
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
| 102 |
+
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
| 103 |
+
|
| 104 |
+
float x_vals[2 * kNElts];
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
| 107 |
+
|
| 108 |
+
float out_vals[kNElts];
|
| 109 |
+
#pragma unroll
|
| 110 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 111 |
+
out_vals[i] = bias_val;
|
| 112 |
+
#pragma unroll
|
| 113 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 114 |
+
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
if (params.silu_activation) {
|
| 119 |
+
#pragma unroll
|
| 120 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 121 |
+
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
input_t out_vals_store[kNElts];
|
| 126 |
+
#pragma unroll
|
| 127 |
+
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
| 128 |
+
if constexpr(kIsVecLoad) {
|
| 129 |
+
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 130 |
+
} else {
|
| 131 |
+
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
| 132 |
+
}
|
| 133 |
+
out += kChunkSize;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 138 |
+
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 139 |
+
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
| 140 |
+
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
| 141 |
+
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
| 142 |
+
constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 143 |
+
dim3 grid(params.batch, params.dim);
|
| 144 |
+
|
| 145 |
+
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
| 146 |
+
|
| 147 |
+
if (kSmemSize >= 48 * 1024) {
|
| 148 |
+
#ifndef USE_ROCM
|
| 149 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 150 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 151 |
+
#else
|
| 152 |
+
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
| 153 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 154 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 155 |
+
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
| 156 |
+
#endif
|
| 157 |
+
}
|
| 158 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 159 |
+
|
| 160 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 161 |
+
});
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
template<typename input_t, typename weight_t>
|
| 165 |
+
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 166 |
+
if (params.width == 2) {
|
| 167 |
+
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 168 |
+
} else if (params.width == 3) {
|
| 169 |
+
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 170 |
+
} else if (params.width == 4) {
|
| 171 |
+
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 176 |
+
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
| 177 |
+
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
| 178 |
+
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
| 179 |
+
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
| 180 |
+
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
| 181 |
+
using input_t = input_t_;
|
| 182 |
+
using weight_t = weight_t_;
|
| 183 |
+
static constexpr int kNThreads = kNThreads_;
|
| 184 |
+
static_assert(kNThreads % 32 == 0);
|
| 185 |
+
static constexpr int kNWarps = kNThreads / 32;
|
| 186 |
+
static constexpr int kWidth = kWidth_;
|
| 187 |
+
static constexpr int kChunkSizeL = kChunkSizeL_;
|
| 188 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 189 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 190 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 191 |
+
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
| 192 |
+
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
| 193 |
+
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
| 194 |
+
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
| 195 |
+
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
| 196 |
+
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
| 197 |
+
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
| 198 |
+
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
| 199 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 200 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 201 |
+
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 202 |
+
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 203 |
+
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 204 |
+
// sizeof(typename BlockStoreT::TempStorage)});
|
| 205 |
+
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
template<typename Ktraits, bool kHasSeqIdx>
|
| 209 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 210 |
+
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
| 211 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 212 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 213 |
+
constexpr int kNElts = Ktraits::kNElts;
|
| 214 |
+
constexpr int kNWarp = Ktraits::kNWarps;
|
| 215 |
+
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
| 216 |
+
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
| 217 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 218 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 219 |
+
using input_t = typename Ktraits::input_t;
|
| 220 |
+
using vec_t = typename Ktraits::vec_t;
|
| 221 |
+
using weight_t = typename Ktraits::weight_t;
|
| 222 |
+
|
| 223 |
+
// Shared memory.
|
| 224 |
+
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
| 225 |
+
|
| 226 |
+
const int batch_id = blockIdx.x;
|
| 227 |
+
const int chunk_l_id = blockIdx.y;
|
| 228 |
+
const int chunk_c_id = blockIdx.z;
|
| 229 |
+
const int tid = threadIdx.x;
|
| 230 |
+
const int l_idx = tid / kNThreadsPerC;
|
| 231 |
+
const int c_idx = tid % kNThreadsPerC;
|
| 232 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 233 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 234 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
| 235 |
+
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
| 236 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 237 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 238 |
+
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
| 239 |
+
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
| 240 |
+
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
| 241 |
+
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 242 |
+
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
|
| 243 |
+
// from the previous L-chunk.
|
| 244 |
+
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
| 245 |
+
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 246 |
+
|
| 247 |
+
#pragma unroll
|
| 248 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 249 |
+
input_t x_vals_load[kNElts] = {0};
|
| 250 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 251 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 252 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
| 253 |
+
}
|
| 254 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 255 |
+
}
|
| 256 |
+
// Load the elements from the previous chunk that are needed for convolution.
|
| 257 |
+
if (l_idx < kWidth - 1) {
|
| 258 |
+
input_t x_vals_load[kNElts] = {0};
|
| 259 |
+
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
| 260 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
| 261 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 262 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
| 263 |
+
} else if (initial_states != nullptr
|
| 264 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
| 265 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 266 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
| 267 |
+
}
|
| 268 |
+
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
__syncthreads();
|
| 272 |
+
|
| 273 |
+
if (final_states != nullptr
|
| 274 |
+
&& l_idx < kWidth - 1
|
| 275 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 276 |
+
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
| 277 |
+
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
| 278 |
+
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
| 282 |
+
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
| 283 |
+
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
| 284 |
+
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
| 285 |
+
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
| 286 |
+
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
| 287 |
+
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
| 288 |
+
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
| 289 |
+
static_assert(kNThreadsPerRow <= 32);
|
| 290 |
+
|
| 291 |
+
const int row_idx = tid / kNThreadsPerRow;
|
| 292 |
+
const int col_idx = tid % kNThreadsPerRow;
|
| 293 |
+
|
| 294 |
+
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
| 295 |
+
float weight_vals[kWidth] = {0};
|
| 296 |
+
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 297 |
+
#pragma unroll
|
| 298 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 299 |
+
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
float x_vals[kWidth - 1 + kLPerThread];
|
| 303 |
+
#pragma unroll
|
| 304 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
| 305 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
| 306 |
+
}
|
| 307 |
+
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
| 308 |
+
if constexpr (kHasSeqIdx) {
|
| 309 |
+
#pragma unroll
|
| 310 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
| 311 |
+
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
float out_vals[kLPerThread];
|
| 316 |
+
#pragma unroll
|
| 317 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
| 318 |
+
out_vals[i] = bias_val;
|
| 319 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 320 |
+
#pragma unroll
|
| 321 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 322 |
+
if constexpr (!kHasSeqIdx) {
|
| 323 |
+
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
| 324 |
+
} else {
|
| 325 |
+
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
__syncthreads();
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
| 334 |
+
__syncthreads();
|
| 335 |
+
|
| 336 |
+
#pragma unroll
|
| 337 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 338 |
+
input_t out_vals_store[kNElts];
|
| 339 |
+
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
| 340 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 341 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 342 |
+
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 349 |
+
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 350 |
+
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
| 351 |
+
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
| 352 |
+
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 353 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 354 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 355 |
+
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
| 356 |
+
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
| 357 |
+
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
| 358 |
+
dim3 block(Ktraits::kNThreads);
|
| 359 |
+
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
| 360 |
+
// if (kSmemSize >= 48 * 1024) {
|
| 361 |
+
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 362 |
+
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 363 |
+
// }
|
| 364 |
+
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 365 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 366 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 367 |
+
});
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
template<typename input_t, typename weight_t>
|
| 371 |
+
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 372 |
+
if (params.width == 2) {
|
| 373 |
+
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 374 |
+
} else if (params.width == 3) {
|
| 375 |
+
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 376 |
+
} else if (params.width == 4) {
|
| 377 |
+
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 382 |
+
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 383 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 384 |
+
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 385 |
+
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 386 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 387 |
+
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 388 |
+
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 389 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 390 |
+
|
| 391 |
+
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 392 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 393 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 394 |
+
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 395 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 396 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 397 |
+
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 398 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 399 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
causal-conv1d/csrc/causal_conv1d_update.cu
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 8 |
+
|
| 9 |
+
#include "causal_conv1d.h"
|
| 10 |
+
#include "causal_conv1d_common.h"
|
| 11 |
+
#include "static_switch.h"
|
| 12 |
+
|
| 13 |
+
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
| 14 |
+
struct Causal_conv1d_update_kernel_traits {
|
| 15 |
+
using input_t = input_t_;
|
| 16 |
+
using weight_t = weight_t_;
|
| 17 |
+
static constexpr int kNThreads = kNThreads_;
|
| 18 |
+
static constexpr int kWidth = kWidth_;
|
| 19 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 20 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
template<typename Ktraits, bool kIsCircularBuffer>
|
| 24 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 25 |
+
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
| 26 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 27 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 28 |
+
using input_t = typename Ktraits::input_t;
|
| 29 |
+
using weight_t = typename Ktraits::weight_t;
|
| 30 |
+
|
| 31 |
+
const int tidx = threadIdx.x;
|
| 32 |
+
const int batch_id = blockIdx.x;
|
| 33 |
+
const int channel_id = blockIdx.y * kNThreads + tidx;
|
| 34 |
+
if (channel_id >= params.dim) return;
|
| 35 |
+
|
| 36 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 37 |
+
+ channel_id * params.x_c_stride;
|
| 38 |
+
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
| 39 |
+
+ channel_id * params.conv_state_c_stride;
|
| 40 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
| 41 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 42 |
+
+ channel_id * params.out_c_stride;
|
| 43 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
| 44 |
+
|
| 45 |
+
int state_len = params.conv_state_len;
|
| 46 |
+
int advance_len = params.seqlen;
|
| 47 |
+
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
| 48 |
+
int update_idx = cache_seqlen - (kWidth - 1);
|
| 49 |
+
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
| 50 |
+
|
| 51 |
+
float weight_vals[kWidth] = {0};
|
| 52 |
+
#pragma unroll
|
| 53 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
| 54 |
+
|
| 55 |
+
float x_vals[kWidth] = {0};
|
| 56 |
+
if constexpr (!kIsCircularBuffer) {
|
| 57 |
+
#pragma unroll 2
|
| 58 |
+
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
| 59 |
+
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
| 60 |
+
}
|
| 61 |
+
#pragma unroll
|
| 62 |
+
for (int i = 0; i < kWidth - 1; ++i) {
|
| 63 |
+
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
| 64 |
+
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
| 65 |
+
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
| 66 |
+
}
|
| 67 |
+
x_vals[i] = float(state_val);
|
| 68 |
+
}
|
| 69 |
+
} else {
|
| 70 |
+
#pragma unroll
|
| 71 |
+
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
| 72 |
+
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
| 73 |
+
x_vals[i] = float(state_val);
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
#pragma unroll 2
|
| 77 |
+
for (int i = 0; i < params.seqlen; ++i) {
|
| 78 |
+
input_t x_val = x[i * params.x_l_stride];
|
| 79 |
+
if constexpr (!kIsCircularBuffer) {
|
| 80 |
+
if (i < advance_len && state_len - advance_len + i >= 0) {
|
| 81 |
+
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
| 82 |
+
}
|
| 83 |
+
} else {
|
| 84 |
+
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
| 85 |
+
++update_idx;
|
| 86 |
+
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
| 87 |
+
}
|
| 88 |
+
x_vals[kWidth - 1] = float(x_val);
|
| 89 |
+
float out_val = bias_val;
|
| 90 |
+
#pragma unroll
|
| 91 |
+
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
| 92 |
+
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
| 93 |
+
out[i * params.out_l_stride] = input_t(out_val);
|
| 94 |
+
// Shift the input buffer by 1
|
| 95 |
+
#pragma unroll
|
| 96 |
+
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 101 |
+
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 102 |
+
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
| 103 |
+
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
| 104 |
+
auto kernel = params.cache_seqlens == nullptr
|
| 105 |
+
? &causal_conv1d_update_kernel<Ktraits, false>
|
| 106 |
+
: &causal_conv1d_update_kernel<Ktraits, true>;
|
| 107 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 108 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
template<typename input_t, typename weight_t>
|
| 112 |
+
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 113 |
+
if (params.width == 2) {
|
| 114 |
+
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
| 115 |
+
} else if (params.width == 3) {
|
| 116 |
+
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
| 117 |
+
} else if (params.width == 4) {
|
| 118 |
+
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 123 |
+
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 124 |
+
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 125 |
+
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 126 |
+
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 127 |
+
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 128 |
+
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 129 |
+
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 130 |
+
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
causal-conv1d/csrc/static_switch.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
| 2 |
+
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
/// @param COND - a boolean expression to switch by
|
| 7 |
+
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
| 8 |
+
/// @param ... - code to execute for true and false
|
| 9 |
+
///
|
| 10 |
+
/// Usage:
|
| 11 |
+
/// ```
|
| 12 |
+
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
| 13 |
+
/// some_function<BoolConst>(...);
|
| 14 |
+
/// });
|
| 15 |
+
/// ```
|
| 16 |
+
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
| 17 |
+
[&] { \
|
| 18 |
+
if (COND) { \
|
| 19 |
+
static constexpr bool CONST_NAME = true; \
|
| 20 |
+
return __VA_ARGS__(); \
|
| 21 |
+
} else { \
|
| 22 |
+
static constexpr bool CONST_NAME = false; \
|
| 23 |
+
return __VA_ARGS__(); \
|
| 24 |
+
} \
|
| 25 |
+
}()
|
causal-conv1d/dist/causal_conv1d-1.4.0-py3.9.egg
ADDED
|
Binary file (10 kB). View file
|
|
|
causal-conv1d/rocm_patch/rocm6_0.patch
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000
|
| 2 |
+
+++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000
|
| 3 |
+
@@ -137,7 +137,7 @@
|
| 4 |
+
* \ingroup HIP_INTRINSIC_BFLOAT16_CONV
|
| 5 |
+
* \brief Converts float to bfloat16
|
| 6 |
+
*/
|
| 7 |
+
-__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {
|
| 8 |
+
+__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) {
|
| 9 |
+
__hip_bfloat16 ret;
|
| 10 |
+
union {
|
| 11 |
+
float fp32;
|
| 12 |
+
@@ -181,7 +181,7 @@
|
| 13 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
| 14 |
+
* \brief Converts and moves bfloat162 to float2
|
| 15 |
+
*/
|
| 16 |
+
-__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
| 17 |
+
+__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) {
|
| 18 |
+
return float2{__bfloat162float(a.x), __bfloat162float(a.y)};
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
@@ -209,7 +209,7 @@
|
| 22 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
| 23 |
+
* \brief Convert double to __hip_bfloat16
|
| 24 |
+
*/
|
| 25 |
+
-__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {
|
| 26 |
+
+__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) {
|
| 27 |
+
return __float2bfloat16((float)a);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
@@ -217,7 +217,7 @@
|
| 31 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
| 32 |
+
* \brief Convert float2 to __hip_bfloat162
|
| 33 |
+
*/
|
| 34 |
+
-__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
| 35 |
+
+__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) {
|
| 36 |
+
return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
@@ -247,7 +247,7 @@
|
| 40 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
| 41 |
+
* \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result
|
| 42 |
+
*/
|
| 43 |
+
-__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
|
| 44 |
+
+__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }
|
| 45 |
+
|
| 46 |
+
/**
|
| 47 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
| 48 |
+
@@ -275,7 +275,7 @@
|
| 49 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
| 50 |
+
* \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result
|
| 51 |
+
*/
|
| 52 |
+
-__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
|
| 53 |
+
+__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }
|
| 54 |
+
|
| 55 |
+
/**
|
| 56 |
+
* \ingroup HIP_INTRINSIC_BFLOAT162_CONV
|
causal-conv1d/setup.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import warnings
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import shutil
|
| 8 |
+
import ast
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from packaging.version import parse, Version
|
| 11 |
+
import platform
|
| 12 |
+
|
| 13 |
+
from setuptools import setup, find_packages
|
| 14 |
+
import subprocess
|
| 15 |
+
|
| 16 |
+
import urllib.request
|
| 17 |
+
import urllib.error
|
| 18 |
+
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
with open("README.md", "r", encoding="utf-8") as fh:
|
| 25 |
+
long_description = fh.read()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ninja build does not work unless include_dirs are abs path
|
| 29 |
+
this_dir = os.path.dirname(os.path.abspath(__file__))
|
| 30 |
+
|
| 31 |
+
PACKAGE_NAME = "causal_conv1d"
|
| 32 |
+
|
| 33 |
+
BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
|
| 34 |
+
|
| 35 |
+
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
|
| 36 |
+
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
|
| 37 |
+
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
|
| 38 |
+
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
|
| 39 |
+
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
|
| 40 |
+
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_platform():
|
| 44 |
+
"""
|
| 45 |
+
Returns the platform name as used in wheel filenames.
|
| 46 |
+
"""
|
| 47 |
+
if sys.platform.startswith("linux"):
|
| 48 |
+
return "linux_x86_64"
|
| 49 |
+
elif sys.platform == "darwin":
|
| 50 |
+
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
|
| 51 |
+
return f"macosx_{mac_version}_x86_64"
|
| 52 |
+
elif sys.platform == "win32":
|
| 53 |
+
return "win_amd64"
|
| 54 |
+
else:
|
| 55 |
+
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_hip_version(rocm_dir):
|
| 61 |
+
|
| 62 |
+
hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
|
| 63 |
+
try:
|
| 64 |
+
raw_output = subprocess.check_output(
|
| 65 |
+
[hipcc_bin, "--version"], universal_newlines=True
|
| 66 |
+
)
|
| 67 |
+
except Exception as e:
|
| 68 |
+
print(
|
| 69 |
+
f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
|
| 70 |
+
)
|
| 71 |
+
return None, None
|
| 72 |
+
|
| 73 |
+
for line in raw_output.split("\n"):
|
| 74 |
+
if "HIP version" in line:
|
| 75 |
+
rocm_version = parse(line.split()[-1].replace("-", "+")) # local version is not parsed correctly
|
| 76 |
+
return line, rocm_version
|
| 77 |
+
|
| 78 |
+
return None, None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_torch_hip_version():
|
| 82 |
+
if torch.version.hip:
|
| 83 |
+
return parse(torch.version.hip.split()[-1].replace("-", "+"))
|
| 84 |
+
else:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def check_if_hip_home_none(global_option: str) -> None:
|
| 89 |
+
|
| 90 |
+
if HIP_HOME is not None:
|
| 91 |
+
return
|
| 92 |
+
# warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
|
| 93 |
+
# in that case.
|
| 94 |
+
warnings.warn(
|
| 95 |
+
f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def check_if_cuda_home_none(global_option: str) -> None:
|
| 100 |
+
if CUDA_HOME is not None:
|
| 101 |
+
return
|
| 102 |
+
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
|
| 103 |
+
# in that case.
|
| 104 |
+
warnings.warn(
|
| 105 |
+
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
|
| 106 |
+
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
|
| 107 |
+
"only images whose names contain 'devel' will provide nvcc."
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def append_nvcc_threads(nvcc_extra_args):
|
| 112 |
+
return nvcc_extra_args + ["--threads", "4"]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
cmdclass = {}
|
| 116 |
+
ext_modules = []
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
HIP_BUILD = bool(torch.version.hip)
|
| 120 |
+
|
| 121 |
+
if not SKIP_CUDA_BUILD:
|
| 122 |
+
|
| 123 |
+
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
| 124 |
+
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
| 125 |
+
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
cc_flag = []
|
| 129 |
+
|
| 130 |
+
if HIP_BUILD:
|
| 131 |
+
check_if_hip_home_none(PACKAGE_NAME)
|
| 132 |
+
|
| 133 |
+
rocm_home = os.getenv("ROCM_PATH")
|
| 134 |
+
_, hip_version = get_hip_version(rocm_home)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if HIP_HOME is not None:
|
| 138 |
+
if hip_version < Version("6.0"):
|
| 139 |
+
raise RuntimeError(
|
| 140 |
+
f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
|
| 141 |
+
"Note: make sure HIP has a supported version by running hipcc --version."
|
| 142 |
+
)
|
| 143 |
+
if hip_version == Version("6.0"):
|
| 144 |
+
warnings.warn(
|
| 145 |
+
f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
|
| 146 |
+
"Refer to the README.md for detailed instructions.",
|
| 147 |
+
UserWarning
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
cc_flag.append("-DBUILD_PYTHON_PACKAGE")
|
| 151 |
+
|
| 152 |
+
else:
|
| 153 |
+
cc_flag.append("-gencode")
|
| 154 |
+
cc_flag.append("arch=compute_53,code=sm_53")
|
| 155 |
+
cc_flag.append("-gencode")
|
| 156 |
+
cc_flag.append("arch=compute_62,code=sm_62")
|
| 157 |
+
cc_flag.append("-gencode")
|
| 158 |
+
cc_flag.append("arch=compute_70,code=sm_70")
|
| 159 |
+
cc_flag.append("-gencode")
|
| 160 |
+
cc_flag.append("arch=compute_72,code=sm_72")
|
| 161 |
+
cc_flag.append("-gencode")
|
| 162 |
+
cc_flag.append("arch=compute_80,code=sm_80")
|
| 163 |
+
cc_flag.append("-gencode")
|
| 164 |
+
cc_flag.append("arch=compute_87,code=sm_87")
|
| 165 |
+
|
| 166 |
+
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
|
| 167 |
+
# torch._C._GLIBCXX_USE_CXX11_ABI
|
| 168 |
+
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
|
| 169 |
+
if FORCE_CXX11_ABI:
|
| 170 |
+
torch._C._GLIBCXX_USE_CXX11_ABI = True
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if HIP_BUILD:
|
| 174 |
+
extra_compile_args = {
|
| 175 |
+
"cxx": ["-O3", "-std=c++17"],
|
| 176 |
+
}
|
| 177 |
+
else:
|
| 178 |
+
extra_compile_args = {
|
| 179 |
+
"cxx": ["-O3"],
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_package_version():
|
| 184 |
+
with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
|
| 185 |
+
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
|
| 186 |
+
public_version = ast.literal_eval(version_match.group(1))
|
| 187 |
+
local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
|
| 188 |
+
if local_version:
|
| 189 |
+
return f"{public_version}+{local_version}"
|
| 190 |
+
else:
|
| 191 |
+
return str(public_version)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_wheel_url():
|
| 195 |
+
|
| 196 |
+
# Determine the version numbers that will be used to determine the correct wheel
|
| 197 |
+
torch_version_raw = parse(torch.__version__)
|
| 198 |
+
|
| 199 |
+
if HIP_BUILD:
|
| 200 |
+
# We're using the HIP version used to build torch, not the one currently installed
|
| 201 |
+
torch_hip_version = get_torch_hip_version()
|
| 202 |
+
hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
|
| 203 |
+
|
| 204 |
+
gpu_compute_version = hip_version if HIP_BUILD else cuda_version
|
| 205 |
+
cuda_or_hip = "hip" if HIP_BUILD else "cu"
|
| 206 |
+
|
| 207 |
+
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
| 208 |
+
platform_name = get_platform()
|
| 209 |
+
causal_conv1d_version = get_package_version()
|
| 210 |
+
|
| 211 |
+
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
|
| 212 |
+
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
|
| 213 |
+
|
| 214 |
+
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
| 215 |
+
wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
| 216 |
+
|
| 217 |
+
wheel_url = BASE_WHEEL_URL.format(
|
| 218 |
+
tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
|
| 219 |
+
)
|
| 220 |
+
return wheel_url, wheel_filename
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class CachedWheelsCommand(_bdist_wheel):
|
| 224 |
+
"""
|
| 225 |
+
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
|
| 226 |
+
find an existing wheel (which is currently the case for all installs). We use
|
| 227 |
+
the environment parameters to detect whether there is already a pre-built version of a compatible
|
| 228 |
+
wheel available and short-circuits the standard full build pipeline.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
def run(self):
|
| 232 |
+
if FORCE_BUILD:
|
| 233 |
+
return super().run()
|
| 234 |
+
|
| 235 |
+
wheel_url, wheel_filename = get_wheel_url()
|
| 236 |
+
print("Guessing wheel URL: ", wheel_url)
|
| 237 |
+
try:
|
| 238 |
+
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
| 239 |
+
|
| 240 |
+
# Make the archive
|
| 241 |
+
# Lifted from the root wheel processing command
|
| 242 |
+
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
|
| 243 |
+
if not os.path.exists(self.dist_dir):
|
| 244 |
+
os.makedirs(self.dist_dir)
|
| 245 |
+
|
| 246 |
+
impl_tag, abi_tag, plat_tag = self.get_tag()
|
| 247 |
+
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
|
| 248 |
+
|
| 249 |
+
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
| 250 |
+
print("Raw wheel path", wheel_path)
|
| 251 |
+
shutil.move(wheel_filename, wheel_path)
|
| 252 |
+
except urllib.error.HTTPError:
|
| 253 |
+
print("Precompiled wheel not found. Building from source...")
|
| 254 |
+
# If the wheel could not be downloaded, build from source
|
| 255 |
+
super().run()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
setup(
|
| 259 |
+
name=PACKAGE_NAME,
|
| 260 |
+
version=get_package_version(),
|
| 261 |
+
packages=find_packages(
|
| 262 |
+
exclude=(
|
| 263 |
+
"build",
|
| 264 |
+
"csrc",
|
| 265 |
+
"include",
|
| 266 |
+
"tests",
|
| 267 |
+
"dist",
|
| 268 |
+
"docs",
|
| 269 |
+
"benchmarks",
|
| 270 |
+
"causal_conv1d.egg-info",
|
| 271 |
+
)
|
| 272 |
+
),
|
| 273 |
+
author="Tri Dao",
|
| 274 |
+
author_email="[email protected]",
|
| 275 |
+
description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
|
| 276 |
+
long_description=long_description,
|
| 277 |
+
long_description_content_type="text/markdown",
|
| 278 |
+
url="https://github.com/Dao-AILab/causal-conv1d",
|
| 279 |
+
classifiers=[
|
| 280 |
+
"Programming Language :: Python :: 3",
|
| 281 |
+
"License :: OSI Approved :: BSD License",
|
| 282 |
+
"Operating System :: Unix",
|
| 283 |
+
],
|
| 284 |
+
ext_modules=ext_modules,
|
| 285 |
+
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
|
| 286 |
+
if ext_modules
|
| 287 |
+
else {
|
| 288 |
+
"bdist_wheel": CachedWheelsCommand,
|
| 289 |
+
},
|
| 290 |
+
python_requires=">=3.8",
|
| 291 |
+
install_requires=[
|
| 292 |
+
"torch",
|
| 293 |
+
"packaging",
|
| 294 |
+
"ninja",
|
| 295 |
+
],
|
| 296 |
+
)
|
causal-conv1d/tests/test_causal_conv1d.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
|
| 13 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
|
| 14 |
+
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states, causal_conv1d_varlen_states_ref
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@pytest.mark.parametrize("return_final_states", [False, True])
|
| 18 |
+
# @pytest.mark.parametrize("return_final_states", [True])
|
| 19 |
+
@pytest.mark.parametrize("has_initial_states", [False, True])
|
| 20 |
+
# @pytest.mark.parametrize("has_initial_states", [False])
|
| 21 |
+
@pytest.mark.parametrize("channel_last", [False, True])
|
| 22 |
+
# @pytest.mark.parametrize('channel_last', [True])
|
| 23 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 24 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 25 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 26 |
+
# @pytest.mark.parametrize('silu_activation', [True])
|
| 27 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 28 |
+
# @pytest.mark.parametrize('has_bias', [True])
|
| 29 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 30 |
+
# @pytest.mark.parametrize('width', [3])
|
| 31 |
+
@pytest.mark.parametrize(
|
| 32 |
+
"seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 33 |
+
)
|
| 34 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 35 |
+
# @pytest.mark.parametrize('seqlen', [128])
|
| 36 |
+
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
| 37 |
+
# @pytest.mark.parametrize('dim', [64])
|
| 38 |
+
def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
|
| 39 |
+
if not channel_last and (has_initial_states or return_final_states):
|
| 40 |
+
pytest.skip("Only channel_last support initial_states or return_final_states")
|
| 41 |
+
device = "cuda"
|
| 42 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 43 |
+
if itype == torch.bfloat16:
|
| 44 |
+
rtol, atol = 1e-2, 5e-2
|
| 45 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 46 |
+
# set seed
|
| 47 |
+
torch.random.manual_seed(0)
|
| 48 |
+
batch = 2
|
| 49 |
+
# batch = 1
|
| 50 |
+
if not channel_last:
|
| 51 |
+
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
| 52 |
+
else:
|
| 53 |
+
x = rearrange(
|
| 54 |
+
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 55 |
+
).requires_grad_()
|
| 56 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 57 |
+
if has_bias:
|
| 58 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 59 |
+
else:
|
| 60 |
+
bias = None
|
| 61 |
+
if has_initial_states:
|
| 62 |
+
initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
|
| 63 |
+
else:
|
| 64 |
+
initial_states = None
|
| 65 |
+
x_ref = x.detach().clone().requires_grad_()
|
| 66 |
+
weight_ref = weight.detach().clone().requires_grad_()
|
| 67 |
+
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
| 68 |
+
initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
|
| 69 |
+
activation = None if not silu_activation else "silu"
|
| 70 |
+
out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
|
| 71 |
+
activation=activation)
|
| 72 |
+
out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
|
| 73 |
+
if return_final_states:
|
| 74 |
+
out, final_states = out
|
| 75 |
+
out_ref, final_states_ref = out_ref
|
| 76 |
+
print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
|
| 77 |
+
print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
|
| 78 |
+
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
|
| 79 |
+
|
| 80 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 81 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 82 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 83 |
+
|
| 84 |
+
if return_final_states:
|
| 85 |
+
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
|
| 86 |
+
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
|
| 87 |
+
|
| 88 |
+
g = torch.randn_like(out)
|
| 89 |
+
out.backward(g)
|
| 90 |
+
out_ref.backward(g)
|
| 91 |
+
|
| 92 |
+
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
| 93 |
+
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
| 94 |
+
if has_bias:
|
| 95 |
+
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
| 96 |
+
if has_initial_states:
|
| 97 |
+
print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")
|
| 98 |
+
|
| 99 |
+
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 100 |
+
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
| 101 |
+
if has_bias:
|
| 102 |
+
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|
| 103 |
+
if has_initial_states:
|
| 104 |
+
assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 108 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 109 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 110 |
+
# @pytest.mark.parametrize('silu_activation', [True])
|
| 111 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 112 |
+
# @pytest.mark.parametrize('has_bias', [True])
|
| 113 |
+
@pytest.mark.parametrize("has_cache_seqlens", [False, True])
|
| 114 |
+
# @pytest.mark.parametrize('has_cache_seqlens', [True])
|
| 115 |
+
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
| 116 |
+
# @pytest.mark.parametrize('seqlen', [4])
|
| 117 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 118 |
+
# @pytest.mark.parametrize('width', [4])
|
| 119 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 120 |
+
# @pytest.mark.parametrize("dim", [2048])
|
| 121 |
+
def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
|
| 122 |
+
device = "cuda"
|
| 123 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 124 |
+
if itype == torch.bfloat16:
|
| 125 |
+
rtol, atol = 1e-2, 5e-2
|
| 126 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 127 |
+
# set seed
|
| 128 |
+
torch.random.manual_seed(0)
|
| 129 |
+
batch = 64
|
| 130 |
+
# batch = 1
|
| 131 |
+
# dim = 64
|
| 132 |
+
x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 133 |
+
state_len = torch.randint(width - 1, width + 10, (1,)).item()
|
| 134 |
+
conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 135 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 136 |
+
if has_bias:
|
| 137 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 138 |
+
else:
|
| 139 |
+
bias = None
|
| 140 |
+
conv_state_ref = conv_state.detach().clone()
|
| 141 |
+
activation = None if not silu_activation else "silu"
|
| 142 |
+
cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
|
| 143 |
+
if has_cache_seqlens else None)
|
| 144 |
+
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 145 |
+
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 146 |
+
|
| 147 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 148 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 149 |
+
assert torch.equal(conv_state, conv_state_ref)
|
| 150 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 154 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 155 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 156 |
+
# @pytest.mark.parametrize("dim", [2048])
|
| 157 |
+
def test_causal_conv1d_get_states(dim, itype):
|
| 158 |
+
device = "cuda"
|
| 159 |
+
# set seed
|
| 160 |
+
torch.random.manual_seed(0)
|
| 161 |
+
seqlens = torch.randint(1, 32, (100,), device=device)
|
| 162 |
+
total_seqlen = seqlens.sum().item()
|
| 163 |
+
x = torch.randn(total_seqlen, dim, device=device, dtype=itype)
|
| 164 |
+
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
|
| 165 |
+
state_len = 20
|
| 166 |
+
out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
|
| 167 |
+
out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
|
| 168 |
+
assert torch.equal(out, out_ref)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# @pytest.mark.parametrize("channel_last", [False, True])
|
| 172 |
+
@pytest.mark.parametrize('channel_last', [True])
|
| 173 |
+
# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 174 |
+
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
| 175 |
+
# @pytest.mark.parametrize("silu_activation", [False, True])
|
| 176 |
+
@pytest.mark.parametrize('silu_activation', [True])
|
| 177 |
+
# @pytest.mark.parametrize("has_bias", [False, True])
|
| 178 |
+
@pytest.mark.parametrize('has_bias', [True])
|
| 179 |
+
# @pytest.mark.parametrize("width", [2, 3, 4])
|
| 180 |
+
@pytest.mark.parametrize('width', [4])
|
| 181 |
+
@pytest.mark.parametrize(
|
| 182 |
+
# "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 183 |
+
"seqlen", [2048]
|
| 184 |
+
)
|
| 185 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 186 |
+
# @pytest.mark.parametrize('seqlen', [128])
|
| 187 |
+
def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
|
| 188 |
+
device = "cuda"
|
| 189 |
+
# set seed
|
| 190 |
+
torch.random.manual_seed(0)
|
| 191 |
+
batch = 2
|
| 192 |
+
# batch = 1
|
| 193 |
+
dim = 4096 + 32 # Try dim not divisible by 64
|
| 194 |
+
# dim = 64
|
| 195 |
+
if not channel_last:
|
| 196 |
+
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
| 197 |
+
else:
|
| 198 |
+
x = rearrange(
|
| 199 |
+
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 200 |
+
).requires_grad_()
|
| 201 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 202 |
+
if has_bias:
|
| 203 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 204 |
+
else:
|
| 205 |
+
bias = None
|
| 206 |
+
activation = None if not silu_activation else "silu"
|
| 207 |
+
out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
|
| 208 |
+
g = torch.randn_like(out0)
|
| 209 |
+
dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
|
| 210 |
+
dw_atol = 1e-4
|
| 211 |
+
db_atol = 1e-4
|
| 212 |
+
|
| 213 |
+
for i in range(10000):
|
| 214 |
+
out = causal_conv1d_fn(x, weight, bias, activation=activation)
|
| 215 |
+
dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
|
| 216 |
+
dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
|
| 217 |
+
# if not dw_equal:
|
| 218 |
+
# breakpoint()
|
| 219 |
+
if has_bias:
|
| 220 |
+
db_equal = torch.allclose(db, db0, atol=db_atol)
|
| 221 |
+
# if not db_equal:
|
| 222 |
+
# breakpoint()
|
| 223 |
+
assert torch.equal(out, out0)
|
| 224 |
+
assert torch.equal(dx, dx0)
|
| 225 |
+
assert dw_equal
|
| 226 |
+
if has_bias:
|
| 227 |
+
assert dw_equal
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 231 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 232 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 233 |
+
# @pytest.mark.parametrize('silu_activation', [False])
|
| 234 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 235 |
+
# @pytest.mark.parametrize('has_bias', [False])
|
| 236 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 237 |
+
# @pytest.mark.parametrize('width', [2])
|
| 238 |
+
@pytest.mark.parametrize(
|
| 239 |
+
"seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 240 |
+
)
|
| 241 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 242 |
+
# @pytest.mark.parametrize('seqlen', [2048])
|
| 243 |
+
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
| 244 |
+
# @pytest.mark.parametrize('dim', [64])
|
| 245 |
+
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
|
| 246 |
+
device = "cuda"
|
| 247 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 248 |
+
if itype == torch.bfloat16:
|
| 249 |
+
rtol, atol = 1e-2, 5e-2
|
| 250 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 251 |
+
# set seed
|
| 252 |
+
torch.random.manual_seed(seqlen + dim + width)
|
| 253 |
+
batch = 3
|
| 254 |
+
seqlens = []
|
| 255 |
+
for b in range(batch):
|
| 256 |
+
nsplits = torch.randint(1, 5, (1,)).item()
|
| 257 |
+
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
| 258 |
+
seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist())
|
| 259 |
+
assert sum(seqlens[-1]) == seqlen
|
| 260 |
+
assert all(s > 0 for s in seqlens[-1])
|
| 261 |
+
# Only support channel_last
|
| 262 |
+
x = rearrange(
|
| 263 |
+
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 264 |
+
).requires_grad_()
|
| 265 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 266 |
+
if has_bias:
|
| 267 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 268 |
+
else:
|
| 269 |
+
bias = None
|
| 270 |
+
seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0)
|
| 271 |
+
for sl in seqlens], dim=0)
|
| 272 |
+
x_ref = x.detach().clone().requires_grad_()
|
| 273 |
+
weight_ref = weight.detach().clone().requires_grad_()
|
| 274 |
+
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
| 275 |
+
activation = None if not silu_activation else "silu"
|
| 276 |
+
out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)
|
| 277 |
+
out_ref = []
|
| 278 |
+
for b in range(batch):
|
| 279 |
+
out_ref_b = []
|
| 280 |
+
for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2):
|
| 281 |
+
out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation))
|
| 282 |
+
out_ref.append(torch.cat(out_ref_b, dim=2))
|
| 283 |
+
out_ref = torch.cat(out_ref, dim=0)
|
| 284 |
+
|
| 285 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 286 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 287 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 288 |
+
|
| 289 |
+
g = torch.randn_like(out)
|
| 290 |
+
out_ref.backward(g)
|
| 291 |
+
out.backward(g)
|
| 292 |
+
|
| 293 |
+
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
| 294 |
+
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
| 295 |
+
if has_bias:
|
| 296 |
+
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
| 297 |
+
|
| 298 |
+
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 299 |
+
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
| 300 |
+
if has_bias:
|
| 301 |
+
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|