BechirTrabelsi1 commited on
Commit
25b4ce2
·
verified ·
1 Parent(s): a47eab0

Training in progress, step 500

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. AutoAWQ_kernels/.github/workflows/build.yaml +232 -0
  2. AutoAWQ_kernels/.github/workflows/scripts/github_create_release.js +17 -0
  3. AutoAWQ_kernels/.gitignore +164 -0
  4. AutoAWQ_kernels/LICENSE +21 -0
  5. AutoAWQ_kernels/README.md +42 -0
  6. AutoAWQ_kernels/awq_ext/attention/cuda_bf16_fallbacks.cuh +257 -0
  7. AutoAWQ_kernels/awq_ext/attention/cuda_bf16_wrapper.h +23 -0
  8. AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.cu +152 -0
  9. AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.h +184 -0
  10. AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_template.hpp +1608 -0
  11. AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_utils.h +1786 -0
  12. AutoAWQ_kernels/awq_ext/attention/ft_attention.cpp +182 -0
  13. AutoAWQ_kernels/awq_ext/attention/ft_attention.h +15 -0
  14. AutoAWQ_kernels/awq_ext/exllama/cu_compat.cuh +58 -0
  15. AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cu +75 -0
  16. AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cuh +55 -0
  17. AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cu +63 -0
  18. AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cuh +19 -0
  19. AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cu +260 -0
  20. AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cuh +43 -0
  21. AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cu +227 -0
  22. AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cuh +53 -0
  23. AutoAWQ_kernels/awq_ext/exllama/exllama_ext.cpp +260 -0
  24. AutoAWQ_kernels/awq_ext/exllama/hip_compat.cuh +51 -0
  25. AutoAWQ_kernels/awq_ext/exllama/matrix.cuh +294 -0
  26. AutoAWQ_kernels/awq_ext/exllama/tuning.h +13 -0
  27. AutoAWQ_kernels/awq_ext/exllama/util.cuh +33 -0
  28. AutoAWQ_kernels/awq_ext/exllamav2/config.h +13 -0
  29. AutoAWQ_kernels/awq_ext/exllamav2/cpp/util.h +12 -0
  30. AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat.cuh +56 -0
  31. AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat_gemm.cuh +38 -0
  32. AutoAWQ_kernels/awq_ext/exllamav2/cuda/matrix_view.cuh +121 -0
  33. AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cu +211 -0
  34. AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cuh +33 -0
  35. AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel.cuh +487 -0
  36. AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel_gptq.cuh +219 -0
  37. AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cu +623 -0
  38. AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cuh +73 -0
  39. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_2.cuh +103 -0
  40. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_3.cuh +169 -0
  41. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_4.cuh +227 -0
  42. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_5.cuh +207 -0
  43. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_6.cuh +44 -0
  44. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_8.cuh +38 -0
  45. AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_util.cuh +51 -0
  46. AutoAWQ_kernels/awq_ext/exllamav2/cuda/util.cuh +42 -0
  47. AutoAWQ_kernels/awq_ext/exllamav2/ext.cpp +134 -0
  48. AutoAWQ_kernels/awq_ext/layernorm/layernorm.cu +113 -0
  49. AutoAWQ_kernels/awq_ext/layernorm/layernorm.h +3 -0
  50. AutoAWQ_kernels/awq_ext/layernorm/reduction.cuh +82 -0
AutoAWQ_kernels/.github/workflows/build.yaml ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Build AutoAWQ Wheels with CUDA
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ jobs:
9
+ release:
10
+ # Retrieve tag and create release
11
+ name: Create Release
12
+ runs-on: ubuntu-latest
13
+ outputs:
14
+ upload_url: ${{ steps.create_release.outputs.upload_url }}
15
+ steps:
16
+ - name: Checkout
17
+ uses: actions/checkout@v3
18
+
19
+ - name: Extract branch info
20
+ shell: bash
21
+ run: |
22
+ echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
23
+
24
+ - name: Create Release
25
+ id: create_release
26
+ uses: "actions/github-script@v6"
27
+ env:
28
+ RELEASE_TAG: ${{ env.release_tag }}
29
+ with:
30
+ github-token: "${{ secrets.GITHUB_TOKEN }}"
31
+ script: |
32
+ const script = require('.github/workflows/scripts/github_create_release.js')
33
+ await script(github, context, core)
34
+
35
+ build_cuda_wheels:
36
+ name: Build AWQ with CUDA
37
+ runs-on: ${{ matrix.os }}
38
+ needs: release
39
+
40
+ strategy:
41
+ matrix:
42
+ os: [ubuntu-20.04, windows-latest]
43
+ pyver: ["3.8", "3.9", "3.10", "3.11", "3.12"]
44
+ cuda: ["11.8.0", "12.1.1"]
45
+ defaults:
46
+ run:
47
+ shell: pwsh
48
+ env:
49
+ PYPI_CUDA_VERSION: "12.1.1"
50
+ CUDA_VERSION: ${{ matrix.cuda }}
51
+
52
+ steps:
53
+ - name: Free Disk Space
54
+ uses: jlumbroso/[email protected]
55
+ if: runner.os == 'Linux'
56
+ with:
57
+ tool-cache: false
58
+ android: true
59
+ dotnet: true
60
+ haskell: true
61
+ large-packages: false
62
+ docker-images: true
63
+ swap-storage: false
64
+
65
+ - uses: actions/checkout@v3
66
+
67
+ - uses: actions/setup-python@v3
68
+ with:
69
+ python-version: ${{ matrix.pyver }}
70
+
71
+ - name: Setup Mamba
72
+ uses: conda-incubator/[email protected]
73
+ with:
74
+ activate-environment: "build"
75
+ python-version: ${{ matrix.pyver }}
76
+ miniforge-variant: Mambaforge
77
+ miniforge-version: latest
78
+ use-mamba: true
79
+ add-pip-as-python-dependency: true
80
+ auto-activate-base: false
81
+
82
+ - name: Install Dependencies
83
+ run: |
84
+ # Install CUDA toolkit
85
+ mamba install -y 'cuda' -c "nvidia/label/cuda-${env:CUDA_VERSION}"
86
+
87
+ # Env variables
88
+ $env:CUDA_PATH = $env:CONDA_PREFIX
89
+ $env:CUDA_HOME = $env:CONDA_PREFIX
90
+
91
+ # Install torch
92
+ $cudaVersion = $env:CUDA_VERSION.Replace('.', '')
93
+ $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1)
94
+ $pytorchVersion = "torch==2.3.1"
95
+ python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch
96
+ python -m pip install build setuptools wheel ninja
97
+
98
+ # Print version information
99
+ python --version
100
+ python -c "import torch; print('PyTorch:', torch.__version__)"
101
+ python -c "import torch; print('CUDA:', torch.version.cuda)"
102
+ python -c "import os; print('CUDA_HOME:', os.getenv('CUDA_HOME', None))"
103
+ python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
104
+
105
+ - name: Build Wheel
106
+ run: |
107
+ $env:CUDA_PATH = $env:CONDA_PREFIX
108
+ $env:CUDA_HOME = $env:CONDA_PREFIX
109
+
110
+ # Only add +cu118 to wheel if not releasing on PyPi
111
+ if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){
112
+ $env:PYPI_BUILD = 1
113
+ }
114
+
115
+ python setup.py sdist bdist_wheel
116
+
117
+ - name: Upload Assets
118
+ uses: shogo82148/actions-upload-release-asset@v1
119
+ with:
120
+ upload_url: ${{ needs.release.outputs.upload_url }}
121
+ asset_path: ./dist/*.whl
122
+
123
+ build_rocm_wheels:
124
+ name: Build AWQ with ROCm
125
+ runs-on: ${{ matrix.os }}
126
+ needs: release
127
+
128
+ strategy:
129
+ matrix:
130
+ os: [ubuntu-20.04]
131
+ python: ["3.8", "3.9", "3.10", "3.11"]
132
+ rocm: ["5.6.1", "5.7.1"] # we build only for rocm5.6 & 5.7 to match PyTorch 2.1.0 and PyTorch 2.2 nightly
133
+ defaults:
134
+ run:
135
+ shell: bash
136
+ env:
137
+ ROCM_VERSION: ${{ matrix.rocm }}
138
+
139
+ steps:
140
+ - uses: actions/checkout@v3
141
+
142
+ - name: Free Disk Space
143
+ run: |
144
+ df -h
145
+ echo "Removing large packages"
146
+ sudo apt-get remove -y '^dotnet-.*'
147
+ sudo apt-get remove -y 'php.*'
148
+ sudo apt-get remove -y azure-cli google-chrome-stable firefox powershell mono-devel
149
+ df -h
150
+ sudo apt-get autoremove -y >/dev/null 2>&1
151
+ sudo apt-get clean
152
+ sudo apt-get autoremove -y >/dev/null 2>&1
153
+ sudo apt-get autoclean -y >/dev/null 2>&1
154
+ df -h
155
+ echo "https://github.com/actions/virtual-environments/issues/709"
156
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
157
+ df -h
158
+ echo "remove big /usr/local"
159
+ sudo rm -rf "/usr/local/share/boost"
160
+ sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
161
+ df -h
162
+ sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
163
+ sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
164
+ sudo rm -rf /usr/share/swift > /dev/null 2>&1
165
+ df -h
166
+
167
+ - uses: actions/setup-python@v3
168
+ with:
169
+ python-version: ${{ matrix.python }}
170
+
171
+ - name: Setup Mamba
172
+ uses: conda-incubator/[email protected]
173
+ with:
174
+ activate-environment: "build"
175
+ python-version: ${{ matrix.python }}
176
+ mamba-version: "*"
177
+ use-mamba: false
178
+ channels: conda-forge,defaults
179
+ channel-priority: true
180
+ add-pip-as-python-dependency: true
181
+ auto-activate-base: false
182
+
183
+ - name: Set up ROCm
184
+ run: |
185
+ echo "Using python:"
186
+ python --version
187
+ which python
188
+
189
+ if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
190
+ export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
191
+ elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
192
+ export ROCM_DL_FILE=amdgpu-install_5.6.50601-1_all.deb
193
+ elif [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
194
+ export ROCM_DL_FILE=amdgpu-install_5.7.50701-1_all.deb
195
+ else
196
+ echo Unknown rocm version
197
+ exit 1
198
+ fi
199
+
200
+ curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
201
+ sudo dpkg -i $ROCM_DL_FILE
202
+ sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y
203
+
204
+ - name: Install Dependencies
205
+ run: |
206
+ sudo apt-get update
207
+ sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev
208
+
209
+ python -m pip install --upgrade build setuptools wheel
210
+
211
+ if [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
212
+ python -m pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/rocm5.7
213
+ elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
214
+ python -m pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/rocm5.6
215
+ else
216
+ echo Unknown rocm version for python install
217
+ exit 1
218
+ fi
219
+
220
+ - name: Build Wheel
221
+ run: |
222
+ echo "Using python for build:"
223
+ python --version
224
+ which python
225
+
226
+ ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
227
+
228
+ - name: Upload Assets
229
+ uses: shogo82148/actions-upload-release-asset@v1
230
+ with:
231
+ upload_url: ${{ needs.release.outputs.upload_url }}
232
+ asset_path: ./dist/*.whl
AutoAWQ_kernels/.github/workflows/scripts/github_create_release.js ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module.exports = async (github, context, core) => {
2
+ try {
3
+ const response = await github.rest.repos.createRelease({
4
+ draft: false,
5
+ generate_release_notes: true,
6
+ name: process.env.RELEASE_TAG,
7
+ owner: context.repo.owner,
8
+ prerelease: false,
9
+ repo: context.repo.repo,
10
+ tag_name: process.env.RELEASE_TAG,
11
+ });
12
+
13
+ core.setOutput('upload_url', response.data.upload_url);
14
+ } catch (error) {
15
+ core.setFailed(error.message);
16
+ }
17
+ }
AutoAWQ_kernels/.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+
163
+ *hip*
164
+ !hip_compact.hip
AutoAWQ_kernels/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Casper
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
AutoAWQ_kernels/README.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AutoAWQ Kernels
2
+
3
+ AutoAWQ Kernels is a new package that is split up from the [main repository](https://github.com/casper-hansen/AutoAWQ) in order to avoid compilation times.
4
+
5
+ ## Requirements
6
+
7
+ - Windows: Must use WSL2.
8
+
9
+ - NVIDIA:
10
+ - GPU: Must be compute capability 7.5 or higher.
11
+ - CUDA Toolkit: Must be 11.8 or higher.
12
+ - AMD:
13
+ - ROCm: Must be 5.6 or higher.
14
+
15
+ ## Install
16
+
17
+ ### Install from PyPi
18
+
19
+ The package is available on PyPi with CUDA 12.1.1 wheels:
20
+
21
+ ```
22
+ pip install autoawq-kernels
23
+ ```
24
+
25
+ ### Install release wheels
26
+
27
+ For ROCm and other CUDA versions, you can use the wheels published at each [release](https://github.com/casper-hansen/AutoAWQ_kernels/releases/):
28
+
29
+ ```
30
+ pip install https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.2/autoawq_kernels-0.0.2+rocm561-cp310-cp310-linux_x86_64.whl
31
+ ```
32
+
33
+ ### Build from source
34
+ You can also build from source:
35
+
36
+ ```
37
+ git clone https://github.com/casper-hansen/AutoAWQ_kernels
38
+ cd AutoAWQ_kernels
39
+ pip install -e .
40
+ ```
41
+
42
+ To build for ROCm, you need to first install the following packages `rocsparse-dev hipsparse-dev rocthrust-dev rocblas-dev hipblas-dev`.
AutoAWQ_kernels/awq_ext/attention/cuda_bf16_fallbacks.cuh ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
3
+ /*
4
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include <cuda_fp16.h>
23
+
24
+ namespace fastertransformer {
25
+
26
+ #ifdef ENABLE_BF16
27
+ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
28
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
29
+ float2 f_val;
30
+ f_val.x = __low2float(val);
31
+ f_val.y = __high2float(val);
32
+ return f_val;
33
+ #else
34
+ return __bfloat1622float2(val);
35
+ #endif
36
+ }
37
+
38
+ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
39
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
40
+ float2 f_val;
41
+ f_val.x = max(min(__low2float(val), 127.f), -128.f);
42
+ f_val.y = max(min(__high2float(val), 127.f), -128.f);
43
+ union { int8_t int8[2]; int16_t int16; };
44
+ int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
45
+ int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
46
+ return int16;
47
+ #else
48
+ val = __hmin2(val, make_bfloat162(127., 127.));
49
+ val = __hmax2(val, make_bfloat162(-128., -128.));
50
+ union { int8_t int8[2]; int16_t int16; };
51
+ int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
52
+ int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
53
+ return int16;
54
+ #endif
55
+ }
56
+
57
+ inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
58
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
59
+ return __floats2bfloat162_rn(val.x, val.y);
60
+ #else
61
+ return __float22bfloat162_rn(val);
62
+ #endif
63
+ }
64
+
65
+ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
66
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
67
+ __nv_bfloat162 val2;
68
+ val2.x = val;
69
+ val2.y = val;
70
+ return val2;
71
+ #else
72
+ return __bfloat162bfloat162(val);
73
+ #endif
74
+ }
75
+
76
+ inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
77
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
78
+ float fxl, fxh, fyl, fyh;
79
+ fxl = __low2float(x);
80
+ fxh = __high2float(x);
81
+ fyl = __low2float(y);
82
+ fyh = __high2float(y);
83
+ return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
84
+ #else
85
+ return __hadd2(x, y);
86
+ #endif
87
+ }
88
+
89
+ inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
90
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
91
+ return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
92
+ #else
93
+ return __hadd(x, y);
94
+ #endif
95
+ }
96
+
97
+ inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
98
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
99
+ float fxl, fxh, fyl, fyh;
100
+ fxl = __low2float(x);
101
+ fxh = __high2float(x);
102
+ fyl = __low2float(y);
103
+ fyh = __high2float(y);
104
+ return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
105
+ #else
106
+ return __hsub2(x, y);
107
+ #endif
108
+ }
109
+
110
+ inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
111
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
112
+ return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
113
+ #else
114
+ return __hsub(x, y);
115
+ #endif
116
+ }
117
+
118
+ inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
119
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
120
+ float fxl, fxh, fyl, fyh;
121
+ fxl = __low2float(x);
122
+ fxh = __high2float(x);
123
+ fyl = __low2float(y);
124
+ fyh = __high2float(y);
125
+ return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
126
+ #else
127
+ return __hmul2(x, y);
128
+ #endif
129
+ }
130
+
131
+ inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
132
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
133
+ return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
134
+ #else
135
+ return __hmul(x, y);
136
+ #endif
137
+ }
138
+
139
+ inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
140
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
141
+ float fxl, fxh, fyl, fyh, fzl, fzh;
142
+ fxl = __low2float(x);
143
+ fxh = __high2float(x);
144
+ fyl = __low2float(y);
145
+ fyh = __high2float(y);
146
+ fzl = __low2float(z);
147
+ fzh = __high2float(z);
148
+ return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
149
+ #else
150
+ return __hfma2(x, y, z);
151
+ #endif
152
+ }
153
+
154
+ inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
155
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
156
+ return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
157
+ #else
158
+ return __hfma(x, y, z);
159
+ #endif
160
+ }
161
+
162
+ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
163
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
164
+ float fxl, fxh;
165
+ fxl = __low2float(x);
166
+ fxh = __high2float(x);;
167
+ return __floats2bfloat162_rn(expf(fxl), expf(fxh));
168
+ #else
169
+ return h2exp(x);
170
+ #endif
171
+ }
172
+
173
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
174
+ inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
175
+ inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
176
+
177
+ inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
178
+ {
179
+ __nv_bfloat162 t; t.x = x; t.y = y; return t;
180
+ }
181
+
182
+ #endif
183
+
184
+ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
185
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
186
+ return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
187
+ #else
188
+ return a + b + c;
189
+ #endif
190
+ }
191
+
192
+ inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
193
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
194
+ return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
195
+ #else
196
+ return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
197
+ #endif
198
+ }
199
+
200
+ inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
201
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
202
+ float fal, fah, fbl, fbh, fcl, fch;
203
+ fal = __low2float(a);
204
+ fah = __high2float(a);
205
+ fbl = __low2float(b);
206
+ fbh = __high2float(b);
207
+ fcl = __low2float(c);
208
+ fch = __high2float(c);
209
+ return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
210
+ #else
211
+ return a + b + c;
212
+ #endif
213
+ }
214
+
215
+ inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
216
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
217
+ return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
218
+ #else
219
+ return a * b * c;
220
+ #endif
221
+ }
222
+
223
+ inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
224
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
225
+ float fal, fah, fbl, fbh, fcl, fch;
226
+ fal = __low2float(a);
227
+ fah = __high2float(a);
228
+ fbl = __low2float(b);
229
+ fbh = __high2float(b);
230
+ fcl = __low2float(c);
231
+ fch = __high2float(c);
232
+ return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
233
+ #else
234
+ return a * b * c;
235
+ #endif
236
+ }
237
+
238
+ inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
239
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
240
+ float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
241
+ fal = __low2float(a);
242
+ fah = __high2float(a);
243
+ fbl = __low2float(b);
244
+ fbh = __high2float(b);
245
+ fcl = __low2float(c);
246
+ fch = __high2float(c);
247
+ fdl = __low2float(d);
248
+ fdh = __high2float(d);
249
+ return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
250
+ #else
251
+ return a * b * c + d;
252
+ #endif
253
+ }
254
+
255
+ #endif // ENABLE_BF16
256
+
257
+ } // namespace fastertransformer
AutoAWQ_kernels/awq_ext/attention/cuda_bf16_wrapper.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
3
+ /*
4
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #ifdef ENABLE_BF16
22
+ #include <cuda_bf16.h>
23
+ #endif
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.cu ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Adapted from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #include "decoder_masked_multihead_attention.h"
20
+ #include "decoder_masked_multihead_attention_utils.h"
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include <assert.h>
23
+ #include <float.h>
24
+ #include <type_traits>
25
+
26
+ #include "decoder_masked_multihead_attention_template.hpp"
27
+
28
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
29
+
30
+ #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
31
+ size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
32
+ auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
33
+ THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
34
+ cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
35
+ dim3 grid(params.num_heads, params.batch_size); \
36
+ kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
37
+
38
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
39
+
40
+ // !!! Specialize the launcher for Cross attention
41
+ template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
42
+ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
43
+ {
44
+ constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
45
+ constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
46
+ int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
47
+ // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
48
+ if (tlength < 32) {
49
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
50
+ }
51
+ else if (tlength < 2048) {
52
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
53
+ }
54
+ else {
55
+ MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
56
+ }
57
+ }
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
60
+
61
+ #undef MMHA_LAUNCH_KERNEL
62
+
63
+ template<typename T, typename KERNEL_PARAMS_TYPE>
64
+ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
65
+ {
66
+ switch (params.hidden_size_per_head) {
67
+ case 32:
68
+ mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
69
+ break;
70
+ case 48:
71
+ mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
72
+ break;
73
+ case 64:
74
+ mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
75
+ break;
76
+ case 80:
77
+ mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
78
+ break;
79
+ case 96:
80
+ mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
81
+ break;
82
+ case 112:
83
+ mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
84
+ break;
85
+ case 128:
86
+ mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
87
+ break;
88
+ case 160:
89
+ mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
90
+ break;
91
+ case 192:
92
+ mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
93
+ break;
94
+ case 224:
95
+ mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
96
+ break;
97
+ case 256:
98
+ mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
99
+ break;
100
+ default:
101
+ assert(false);
102
+ }
103
+ }
104
+
105
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
106
+
107
+ void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
108
+ {
109
+ multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
110
+ }
111
+
112
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
113
+
114
+ void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
115
+ {
116
+ multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
117
+ }
118
+
119
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
120
+
121
+ #ifdef ENABLE_BF16
122
+ void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
123
+ const cudaStream_t& stream)
124
+ {
125
+ multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
126
+ }
127
+ #endif
128
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
129
+
130
+ void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
131
+ {
132
+ multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
133
+ }
134
+
135
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
136
+
137
+ void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
138
+ {
139
+ multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
140
+ }
141
+
142
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
143
+
144
+ #ifdef ENABLE_BF16
145
+ void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
146
+ const cudaStream_t& stream)
147
+ {
148
+ multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
149
+ }
150
+ #endif
151
+
152
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.h ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include <cuda_fp16.h>
23
+ #include <cuda_runtime_api.h>
24
+ #include <stdint.h>
25
+ #include <stdio.h>
26
+ #include <stdlib.h>
27
+
28
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
29
+
30
+ #define CHECK_CUDA(call) \
31
+ do { \
32
+ cudaError_t status_ = call; \
33
+ if (status_ != cudaSuccess) { \
34
+ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
35
+ exit(1); \
36
+ } \
37
+ } while (0)
38
+
39
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ // The structure of parameters for the masked multihead attention kernel.
42
+ //
43
+ // We use the following terminology to describe the different dimensions.
44
+ //
45
+ // B: Batch size (number of sequences),
46
+ // L: Sequence length,
47
+ // D: Hidden dimension,
48
+ // H: Number of heads,
49
+ // Dh: Hidden dimension per head - Dh = D / H.
50
+
51
+ template<typename T>
52
+ struct Multihead_attention_params_base {
53
+
54
+ // The output buffer. Dimensions B x D.
55
+ T* out = nullptr;
56
+
57
+ // The input Qs and the associated bias. Dimensions B x D and D, resp.
58
+ const T *q = nullptr, *q_bias = nullptr;
59
+ // The input Ks and the associated bias. Dimensions B x D and D, resp.
60
+ const T *k = nullptr, *k_bias = nullptr;
61
+ // The input Vs and the associated bias. Dimensions B x D and D, resp.
62
+ const T *v = nullptr, *v_bias = nullptr;
63
+
64
+ // The cache for the Ks. The size must be at least B x L x D.
65
+ T* k_cache = nullptr;
66
+ // The cache for the Vs. The size must be at least B x L x D.
67
+ T* v_cache = nullptr;
68
+ // The indirections to use for cache when beam sampling.
69
+ const int* cache_indir = nullptr;
70
+
71
+ // Stride to handle the case when KQV is a single buffer
72
+ int stride = 0;
73
+
74
+ // The batch size.
75
+ int batch_size = 0;
76
+ // The beam width
77
+ int beam_width = 0;
78
+ // The sequence length.
79
+ int memory_max_len = 0;
80
+ // The number of heads (H).
81
+ int num_heads = 0;
82
+ // The number of heads for KV cache.
83
+ int num_kv_heads = 0;
84
+ // The hidden dimension per head (Dh).
85
+ int hidden_size_per_head = 0;
86
+ // The per-head latent space reserved for rotary embeddings.
87
+ int rotary_embedding_dim = 0;
88
+ bool neox_rotary_style = false;
89
+ float rotary_base = 0.0f;
90
+ // The maximum length of input sentences.
91
+ int max_input_length = 0;
92
+ // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
93
+ int timestep = 0;
94
+ // The current timestep of each sentences (support different timestep for different sentences)
95
+
96
+ // The 1.f / sqrt(Dh). Computed on the host.
97
+ float inv_sqrt_dh = 0.0f;
98
+
99
+ // Used when we have some input context like gpt
100
+ const int* total_padding_tokens = nullptr;
101
+
102
+ const bool* masked_tokens = nullptr;
103
+ const int* prefix_prompt_lengths = nullptr;
104
+ int max_prefix_prompt_length = 0;
105
+
106
+ const T* relative_attention_bias = nullptr;
107
+ int relative_attention_bias_stride = 0;
108
+ // The slope per head of linear position bias to attention score (H).
109
+ const float* linear_bias_slopes = nullptr;
110
+
111
+ const T* ia3_key_weights = nullptr;
112
+ const T* ia3_value_weights = nullptr;
113
+ const int* ia3_tasks = nullptr;
114
+
115
+ const float* qkv_scale_out = nullptr;
116
+ const float* attention_out_scale = nullptr;
117
+ int int8_mode = 0;
118
+ };
119
+
120
+ template<typename T, bool CROSS_ATTENTION>
121
+ struct Multihead_attention_params: public Multihead_attention_params_base<T> {
122
+ // output cross attentions
123
+ float* cross_attention_out = nullptr;
124
+ int max_decoder_seq_len = 0;
125
+ bool is_return_cross_attentions = false;
126
+
127
+ // allows to exist attention eary
128
+ bool* finished = nullptr;
129
+
130
+ // required in case of cross attention
131
+ // will need it here till if constexpr in c++17
132
+ int* memory_length_per_sample = nullptr;
133
+
134
+ // required in case of masked attention with different length
135
+ const int* length_per_sample = nullptr;
136
+ };
137
+
138
+ template<typename T>
139
+ struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
140
+ // output cross attentions
141
+ float* cross_attention_out = nullptr;
142
+ int max_decoder_seq_len = 0;
143
+ bool is_return_cross_attentions = false;
144
+
145
+ // allows to exist attention eary
146
+ bool* finished = nullptr;
147
+
148
+ // required in case of cross attention
149
+ int* memory_length_per_sample = nullptr;
150
+
151
+ // required in case of masked attention with different length
152
+ const int* length_per_sample = nullptr;
153
+ };
154
+
155
+ template<class T>
156
+ using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
157
+
158
+ template<class T>
159
+ using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
160
+
161
+ template<typename T>
162
+ struct outputCrossAttentionParam {
163
+ // max decoder output length
164
+ int max_decoder_seq_len = 0;
165
+ T* cross_attention_out = nullptr;
166
+ bool is_return_cross_attentions = false;
167
+ };
168
+
169
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
170
+
171
+ void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
172
+ void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
173
+ #ifdef ENABLE_BF16
174
+ void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
175
+ const cudaStream_t& stream);
176
+ #endif
177
+ void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
178
+ void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
179
+ #ifdef ENABLE_BF16
180
+ void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
181
+ const cudaStream_t& stream);
182
+ #endif
183
+
184
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_template.hpp ADDED
@@ -0,0 +1,1608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Downloaded from from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+ #pragma once
19
+
20
+ #include "decoder_masked_multihead_attention.h"
21
+ #include "decoder_masked_multihead_attention_utils.h"
22
+ #include "cuda_bf16_wrapper.h"
23
+ #include "cuda_bf16_fallbacks.cuh"
24
+ #include <assert.h>
25
+ #include <float.h>
26
+ #include <type_traits>
27
+
28
+ // #define MMHA_USE_HMMA_FOR_REDUCTION
29
+
30
+ // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
31
+
32
+ // Does not seem to affect the accuracy that much
33
+ #define MMHA_USE_FP32_ACUM_FOR_FMA
34
+
35
+ // Seems to slightly improve the accuracy
36
+ #define MMHA_USE_FP32_ACUM_FOR_OUT
37
+
38
+ #if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
39
+ // Does not seem to improve the accuracy
40
+ //#define MMHA_USE_FP32_ACUM_FOR_LOGITS
41
+ #endif
42
+
43
+ namespace mmha {
44
+
45
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ //
48
+ // We use the following terminology to describe the different dimensions.
49
+ //
50
+ // B: Batch size (number of sequences),
51
+ // L: Sequence length,
52
+ // D: Hidden dimension,
53
+ // H: Number of heads,
54
+ // Dh: Hidden dimension per head - Dh = D / H.
55
+ //
56
+ // The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
57
+ // 64, 128 and 256 threads per block.
58
+ //
59
+ // Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
60
+ // compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
61
+ // cache buffer helps with memory accesses and contains keys with bias.
62
+ //
63
+ // The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
64
+ // x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
65
+ // values for x are chosen to create chunks of 16 bytes.
66
+ //
67
+ // The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
68
+ // depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
69
+ // the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
70
+ // HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32.
71
+ //
72
+ // After that loop, a parallel softmax is computed across the different Q * K^T values stored in
73
+ // shared memory.
74
+ //
75
+ // The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
76
+ // timesteps are computed by loop iteration. As with the keys, the values are read from a cache
77
+ // except for the current timestep. The layout of the cache buffer for the values is much simpler
78
+ // as it is [B, H, L, Dh].
79
+ //
80
+
81
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
82
+
83
+ template<typename T, int Dh>
84
+ struct Qk_vec_ {
85
+ };
86
+
87
+ template<>
88
+ struct Qk_vec_<float, 32> {
89
+ using Type = float;
90
+ };
91
+ template<>
92
+ struct Qk_vec_<float, 64> {
93
+ using Type = float2;
94
+ };
95
+ template<>
96
+ struct Qk_vec_<float, 128> {
97
+ using Type = float4;
98
+ };
99
+ template<>
100
+ struct Qk_vec_<float, 256> {
101
+ using Type = float4;
102
+ };
103
+ template<>
104
+ struct Qk_vec_<uint16_t, 32> {
105
+ using Type = uint32_t;
106
+ };
107
+ template<>
108
+ struct Qk_vec_<uint16_t, 64> {
109
+ using Type = uint32_t;
110
+ };
111
+ template<>
112
+ struct Qk_vec_<uint16_t, 128> {
113
+ using Type = uint2;
114
+ };
115
+ template<>
116
+ struct Qk_vec_<uint16_t, 256> {
117
+ using Type = uint4;
118
+ };
119
+ #ifdef ENABLE_BF16
120
+ template<>
121
+ struct Qk_vec_<__nv_bfloat16, 32> {
122
+ using Type = __nv_bfloat162;
123
+ };
124
+ template<>
125
+ struct Qk_vec_<__nv_bfloat16, 64> {
126
+ using Type = __nv_bfloat162;
127
+ };
128
+ template<>
129
+ struct Qk_vec_<__nv_bfloat16, 128> {
130
+ using Type = bf16_4_t;
131
+ };
132
+ template<>
133
+ struct Qk_vec_<__nv_bfloat16, 256> {
134
+ using Type = bf16_8_t;
135
+ };
136
+ #endif // ENABLE_BF16
137
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
138
+
139
+ template<typename T, int THREADS_PER_KEY>
140
+ struct K_vec_ {
141
+ };
142
+
143
+ template<>
144
+ struct K_vec_<float, 4> {
145
+ using Type = float;
146
+ };
147
+ template<>
148
+ struct K_vec_<float, 2> {
149
+ using Type = float2;
150
+ };
151
+ template<>
152
+ struct K_vec_<float, 1> {
153
+ using Type = float4;
154
+ };
155
+ template<>
156
+ struct K_vec_<uint16_t, 4> {
157
+ using Type = uint32_t;
158
+ };
159
+ template<>
160
+ struct K_vec_<uint16_t, 2> {
161
+ using Type = uint2;
162
+ };
163
+ template<>
164
+ struct K_vec_<uint16_t, 1> {
165
+ using Type = uint4;
166
+ };
167
+ #ifdef ENABLE_BF16
168
+ template<>
169
+ struct K_vec_<__nv_bfloat16, 4> {
170
+ using Type = __nv_bfloat162;
171
+ };
172
+ template<>
173
+ struct K_vec_<__nv_bfloat16, 2> {
174
+ using Type = bf16_4_t;
175
+ };
176
+ template<>
177
+ struct K_vec_<__nv_bfloat16, 1> {
178
+ using Type = bf16_8_t;
179
+ };
180
+ #endif // ENABLE_BF16
181
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
182
+
183
+ template<typename T, int V_VEC_SIZE>
184
+ struct V_vec_ {
185
+ };
186
+
187
+ template<>
188
+ struct V_vec_<float, 1> {
189
+ using Type = float;
190
+ };
191
+ template<>
192
+ struct V_vec_<float, 2> {
193
+ using Type = float2;
194
+ };
195
+ template<>
196
+ struct V_vec_<float, 4> {
197
+ using Type = float4;
198
+ };
199
+ template<>
200
+ struct V_vec_<uint16_t, 2> {
201
+ using Type = uint32_t;
202
+ };
203
+ template<>
204
+ struct V_vec_<uint16_t, 4> {
205
+ using Type = uint2;
206
+ };
207
+ template<>
208
+ struct V_vec_<uint16_t, 8> {
209
+ using Type = uint4;
210
+ };
211
+ #ifdef ENABLE_BF16
212
+ template<>
213
+ struct V_vec_<__nv_bfloat16, 2> {
214
+ using Type = __nv_bfloat162;
215
+ };
216
+ template<>
217
+ struct V_vec_<__nv_bfloat16, 4> {
218
+ using Type = bf16_4_t;
219
+ };
220
+ template<>
221
+ struct V_vec_<__nv_bfloat16, 8> {
222
+ using Type = bf16_8_t;
223
+ };
224
+ #endif // ENABLE_BF16
225
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
226
+
227
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
228
+ template<typename T>
229
+ struct Qk_vec_acum_fp32_ {
230
+ };
231
+
232
+ template<>
233
+ struct Qk_vec_acum_fp32_<float> {
234
+ using Type = float;
235
+ };
236
+ template<>
237
+ struct Qk_vec_acum_fp32_<float2> {
238
+ using Type = float2;
239
+ };
240
+ template<>
241
+ struct Qk_vec_acum_fp32_<float4> {
242
+ using Type = float4;
243
+ };
244
+ // template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
245
+ template<>
246
+ struct Qk_vec_acum_fp32_<uint32_t> {
247
+ using Type = float2;
248
+ };
249
+ template<>
250
+ struct Qk_vec_acum_fp32_<uint2> {
251
+ using Type = Float4_;
252
+ };
253
+ template<>
254
+ struct Qk_vec_acum_fp32_<uint4> {
255
+ using Type = Float8_;
256
+ };
257
+ template<>
258
+ struct Qk_vec_acum_fp32_<__nv_bfloat16> {
259
+ using Type = float;
260
+ };
261
+ template<>
262
+ struct Qk_vec_acum_fp32_<__nv_bfloat162> {
263
+ using Type = float2;
264
+ };
265
+ template<>
266
+ struct Qk_vec_acum_fp32_<bf16_4_t> {
267
+ using Type = Float4_;
268
+ };
269
+ template<>
270
+ struct Qk_vec_acum_fp32_<bf16_8_t> {
271
+ using Type = Float8_;
272
+ };
273
+
274
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
275
+
276
+ template<typename T>
277
+ struct K_vec_acum_fp32_ {
278
+ };
279
+
280
+ template<>
281
+ struct K_vec_acum_fp32_<float> {
282
+ using Type = float;
283
+ };
284
+ template<>
285
+ struct K_vec_acum_fp32_<float2> {
286
+ using Type = float2;
287
+ };
288
+ template<>
289
+ struct K_vec_acum_fp32_<float4> {
290
+ using Type = float4;
291
+ };
292
+ template<>
293
+ struct K_vec_acum_fp32_<uint32_t> {
294
+ using Type = float2;
295
+ };
296
+ template<>
297
+ struct K_vec_acum_fp32_<uint2> {
298
+ using Type = Float4_;
299
+ };
300
+ template<>
301
+ struct K_vec_acum_fp32_<uint4> {
302
+ using Type = Float8_;
303
+ };
304
+ template<>
305
+ struct K_vec_acum_fp32_<__nv_bfloat16> {
306
+ using Type = float;
307
+ };
308
+ template<>
309
+ struct K_vec_acum_fp32_<__nv_bfloat162> {
310
+ using Type = float2;
311
+ };
312
+ template<>
313
+ struct K_vec_acum_fp32_<bf16_4_t> {
314
+ using Type = Float4_;
315
+ };
316
+ template<>
317
+ struct K_vec_acum_fp32_<bf16_8_t> {
318
+ using Type = Float8_;
319
+ };
320
+ #endif
321
+
322
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
323
+
324
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
325
+ template<typename T>
326
+ struct V_vec_acum_fp32_ {
327
+ };
328
+
329
+ template<>
330
+ struct V_vec_acum_fp32_<float> {
331
+ using Type = float;
332
+ };
333
+ template<>
334
+ struct V_vec_acum_fp32_<float2> {
335
+ using Type = float2;
336
+ };
337
+ template<>
338
+ struct V_vec_acum_fp32_<float4> {
339
+ using Type = float4;
340
+ };
341
+ template<>
342
+ struct V_vec_acum_fp32_<uint32_t> {
343
+ using Type = float2;
344
+ };
345
+ template<>
346
+ struct V_vec_acum_fp32_<uint2> {
347
+ using Type = Float4_;
348
+ };
349
+ template<>
350
+ struct V_vec_acum_fp32_<uint4> {
351
+ using Type = Float8_;
352
+ };
353
+ #ifdef ENABLE_BF16
354
+ template<>
355
+ struct V_vec_acum_fp32_<__nv_bfloat162> {
356
+ using Type = float2;
357
+ };
358
+ template<>
359
+ struct V_vec_acum_fp32_<bf16_4_t> {
360
+ using Type = Float4_;
361
+ };
362
+ template<>
363
+ struct V_vec_acum_fp32_<bf16_8_t> {
364
+ using Type = Float8_;
365
+ };
366
+ #endif // ENABLE_BF16
367
+ #endif
368
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
369
+
370
+ template<int THREADS_PER_KEY, typename K_vec, int N>
371
+ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
372
+ {
373
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
374
+ using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
375
+ #else
376
+ using K_vec_acum = K_vec;
377
+ #endif
378
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
379
+ K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
380
+ #pragma unroll
381
+ for (int ii = 1; ii < N; ++ii) {
382
+ qk_vec = fma(q[ii], k[ii], qk_vec);
383
+ }
384
+
385
+ // Finalize the reduction across lanes.
386
+ float qk = sum(qk_vec);
387
+ #pragma unroll
388
+ for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
389
+ qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
390
+ }
391
+ return qk;
392
+ }
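The __shfl_xor_sync loop above is a butterfly (XOR) shuffle reduction: after log2(THREADS_PER_KEY) exchanges, every lane in a group of THREADS_PER_KEY consecutive lanes holds that group's partial Q*K^T sum. A self-contained sketch of the same pattern; names and sizes are chosen for illustration and are not taken from this repo:

// Standalone sketch of the butterfly reduction: each group of
// THREADS_PER_KEY consecutive lanes ends up holding the sum of its inputs.
#include <cstdio>
#include <cuda_runtime.h>

template<int THREADS_PER_KEY>
__global__ void group_sum(const float* in, float* out)
{
    float v = in[threadIdx.x];
#pragma unroll
    for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
        // XOR-shuffle pairs lanes whose indices differ in exactly one bit.
        v += __shfl_xor_sync(0xffffffffu, v, mask);
    }
    out[threadIdx.x] = v;  // every lane in the group now holds the group sum
}

int main()
{
    float h[32], *d_in, *d_out;
    for (int i = 0; i < 32; ++i) h[i] = 1.0f;
    cudaMalloc(&d_in, 32 * sizeof(float));
    cudaMalloc(&d_out, 32 * sizeof(float));
    cudaMemcpy(d_in, h, 32 * sizeof(float), cudaMemcpyHostToDevice);
    group_sum<4><<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h, d_out, 32 * sizeof(float), cudaMemcpyDeviceToHost);
    printf("%f\n", h[0]);  // expect 4.0 with THREADS_PER_KEY = 4
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}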
393
+
394
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
395
+
396
+ template<typename T, int THREADS_PER_KEY>
397
+ struct Qk_dot {
398
+ template<typename K_vec, int N>
399
+ static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
400
+ {
401
+ return qk_dot_<THREADS_PER_KEY>(q, k);
402
+ }
403
+ };
404
+
405
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
406
+
407
+ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
408
+ {
409
+ float4 c;
410
+ float zero = 0.f;
411
+ asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
412
+ " {%0, %1, %2, %3}, \n"
413
+ " {%4, %5}, \n"
414
+ " {%6}, \n"
415
+ " {%7, %7, %7, %7}; \n"
416
+
417
+ : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
418
+ : "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
419
+ return c;
420
+ }
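hmma_fp32 feeds the packed fp16 partial products through a 16x8x8 tensor-core MMA whose B operand is the constant 0x3c003c00, i.e. the half2 value {1.0, 1.0}, so the instruction effectively reduces the fp16 lanes of qk_vec (multiply by 1.0, accumulate in fp32). A quick standalone check of that bit pattern, for illustration only:

// fp16 layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
// 1.0f => sign 0, exponent 15, mantissa 0 => 0x3c00; packed twice => 0x3c003c00.
#include <cstdio>

int main()
{
    unsigned short one_fp16 = (0u << 15) | (15u << 10) | 0u;                       // 0x3c00
    unsigned int   packed   = (static_cast<unsigned int>(one_fp16) << 16) | one_fp16;
    printf("0x%04x 0x%08x\n", one_fp16, packed);                                   // 0x3c00 0x3c003c00
    return 0;
}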
421
+
422
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
423
+
424
+ template<int N>
425
+ inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
426
+ {
427
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
428
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
429
+ using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
430
+ #else
431
+ using K_vec_acum = uint32_t;
432
+ #endif
433
+ K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
434
+ #pragma unroll
435
+ for (int ii = 1; ii < N; ++ii) {
436
+ qk_vec = fma(q[ii], k[ii], qk_vec);
437
+ }
438
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
439
+ uint32_t qk_vec_ = float2_to_half2(qk_vec);
440
+ return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
441
+ #else
442
+ return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
443
+ #endif
444
+ #else
445
+ return 0.f;
446
+ #endif
447
+ }
448
+
449
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
450
+
451
+ template<>
452
+ struct Qk_dot<uint16_t, 4> {
453
+ template<int N>
454
+ static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
455
+ {
456
+ #if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
457
+ return qk_hmma_dot_(q, k);
458
+ #else
459
+ return qk_dot_<4>(q, k);
460
+ #endif // defined MMHA_USE_HMMA_FOR_REDUCTION
461
+ }
462
+ };
463
+
464
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
465
+
466
+ template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
467
+ inline __device__ float block_sum(float* red_smem, float sum)
468
+ {
469
+
470
+ // Decompose the thread index into warp / lane.
471
+ int warp = threadIdx.x / WARP_SIZE;
472
+ int lane = threadIdx.x % WARP_SIZE;
473
+
474
+ // Compute the sum per warp.
475
+ #pragma unroll
476
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
477
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
478
+ }
479
+
480
+ // Warp leaders store the data to shared memory.
481
+ if (lane == 0) {
482
+ red_smem[warp] = sum;
483
+ }
484
+
485
+ // Make sure the data is in shared memory.
486
+ __syncthreads();
487
+
488
+ // The warps compute the final sums.
489
+ if (lane < WARPS_PER_BLOCK) {
490
+ sum = red_smem[lane];
491
+ }
492
+
493
+ // Parallel reduction inside the warp.
494
+ #pragma unroll
495
+ for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
496
+ sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
497
+ }
498
+
499
+ // Broadcast to other threads.
500
+ return __shfl_sync(uint32_t(-1), sum, 0);
501
+ }
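block_sum is a two-level reduction: an intra-warp butterfly, warp leaders spilling one partial each to red_smem, then a shuffle across the (at most 32) warp partials with the result broadcast from lane 0. A usage sketch, assuming the block_sum above is in scope; the caller must supply at least WARPS_PER_BLOCK floats of shared scratch and every thread of the block must reach the call:

// Usage sketch (illustration only): counts the threads of a block.
template<int THREADS_PER_BLOCK>
__global__ void count_threads(float* out)
{
    constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / 32;
    // One float of scratch per warp, as block_sum expects.
    __shared__ float red_smem[WARPS_PER_BLOCK];
    float total = block_sum<WARPS_PER_BLOCK>(red_smem, 1.0f);
    if (threadIdx.x == 0) {
        out[blockIdx.x] = total;  // equals THREADS_PER_BLOCK
    }
}
// e.g. count_threads<256><<<grid, 256>>>(d_out);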
502
+
503
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
504
+
505
+ inline __device__ void convert_from_float(float& dst, float src)
506
+ {
507
+ dst = src;
508
+ }
509
+
510
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
511
+
512
+ inline __device__ void convert_from_float(uint16_t& dst, float src)
513
+ {
514
+ dst = float_to_half(src);
515
+ }
516
+
517
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
518
+
519
+ inline __device__ void convert_from_float(uint32_t& dst, float2 src)
520
+ {
521
+ dst = float2_to_half2(src);
522
+ }
523
+
524
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
525
+ #ifdef ENABLE_BF16
526
+ inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
527
+ {
528
+ dst = __float2bfloat16(src);
529
+ }
530
+
531
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
532
+
533
+ inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
534
+ {
535
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
536
+ dst = __float22bfloat162_rn(src);
537
+ #else
538
+ dst = __floats2bfloat162_rn(src.x, src.y);
539
+ #endif
540
+ }
541
+ #endif // ENABLE_BF16
542
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
543
+
544
+ inline __device__ void convert_from_float(uint2& dst, Float4_ src)
545
+ {
546
+ dst.x = float2_to_half2(src.x);
547
+ dst.y = float2_to_half2(src.y);
548
+ }
549
+
550
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
551
+
552
+ inline __device__ void convert_from_float(uint2& dst, float4 src)
553
+ {
554
+ convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
555
+ }
556
+
557
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
558
+
559
+ inline __device__ void convert_from_float(uint4& dst, Float8_ src)
560
+ {
561
+ dst.x = float2_to_half2(src.x);
562
+ dst.y = float2_to_half2(src.y);
563
+ dst.z = float2_to_half2(src.z);
564
+ dst.w = float2_to_half2(src.w);
565
+ }
566
+
567
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
568
+
569
+ #ifdef ENABLE_BF16
570
+ inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
571
+ {
572
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
573
+ dst.x = __float22bfloat162_rn(src.x);
574
+ dst.y = __float22bfloat162_rn(src.y);
575
+ #else
576
+ dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
577
+ dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
578
+ #endif
579
+ }
580
+
581
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
582
+
583
+ inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
584
+ {
585
+ convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
586
+ }
587
+
588
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
589
+
590
+ inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
591
+ {
592
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
593
+ dst.x = __float22bfloat162_rn(src.x);
594
+ dst.y = __float22bfloat162_rn(src.y);
595
+ dst.z = __float22bfloat162_rn(src.z);
596
+ dst.w = __float22bfloat162_rn(src.w);
597
+ #else
598
+ dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
599
+ dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
600
+ dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
601
+ dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
602
+ #endif
603
+ }
604
+ #endif // ENABLE_BF16
605
+
606
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
607
+
608
+ inline __device__ void convert_from_float(float2& dst, float2 src)
609
+ {
610
+ dst = src;
611
+ }
612
+
613
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
614
+
615
+ inline __device__ void convert_from_float(float4& dst, float4 src)
616
+ {
617
+ dst = src;
618
+ }
619
+
620
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
621
+
622
+ inline __device__ float convert_to_float(float4 u)
623
+ {
624
+ return u.x;
625
+ }
626
+
627
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
628
+
629
+ inline __device__ float convert_to_float(uint4 u)
630
+ {
631
+ float2 tmp = half2_to_float2(u.x);
632
+ return tmp.x;
633
+ }
634
+
635
+ #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
636
+
637
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
638
+
639
+ inline __device__ float cast_to_float(float u)
640
+ {
641
+ return u;
642
+ }
643
+
644
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
645
+
646
+ inline __device__ float2 cast_to_float(float2 u)
647
+ {
648
+ return u;
649
+ }
650
+
651
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
652
+
653
+ inline __device__ float4 cast_to_float(float4 u)
654
+ {
655
+ return u;
656
+ }
657
+
658
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
659
+
660
+ inline __device__ Float4_ cast_to_float(Float4_ u)
661
+ {
662
+ return u;
663
+ }
664
+
665
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
666
+
667
+ inline __device__ Float8_ cast_to_float(Float8_ u)
668
+ {
669
+ return u;
670
+ }
671
+
672
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
673
+
674
+ inline __device__ float2 cast_to_float(uint32_t u)
675
+ {
676
+ return half2_to_float2(u);
677
+ }
678
+
679
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
680
+
681
+ inline __device__ Float4_ cast_to_float(uint2 u)
682
+ {
683
+ Float4_ tmp;
684
+ tmp.x = half2_to_float2(u.x);
685
+ tmp.y = half2_to_float2(u.y);
686
+ return tmp;
687
+ }
688
+
689
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
690
+
691
+ inline __device__ Float8_ cast_to_float(uint4 u)
692
+ {
693
+ Float8_ tmp;
694
+ tmp.x = half2_to_float2(u.x);
695
+ tmp.y = half2_to_float2(u.y);
696
+ tmp.z = half2_to_float2(u.z);
697
+ tmp.w = half2_to_float2(u.w);
698
+ return tmp;
699
+ }
700
+
701
+ #endif
702
+
703
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
704
+
705
+ inline __device__ float float_from_int8(int8_t u)
706
+ {
707
+ return u;
708
+ }
709
+
710
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
711
+
712
+ inline __device__ float2 float_from_int8(int16_t u)
713
+ {
714
+ union {
715
+ int16_t int16;
716
+ int8_t int8[2];
717
+ };
718
+ int16 = u;
719
+ return make_float2(int8[0], int8[1]);
720
+ }
721
+
722
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
723
+
724
+ inline __device__ float4 float_from_int8(int32_t u)
725
+ {
726
+ union {
727
+ int32_t int32;
728
+ int8_t int8[4];
729
+ };
730
+ int32 = u;
731
+ return make_float4(int8[0], int8[1], int8[2], int8[3]);
732
+ }
733
+
734
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
735
+
736
+ // clang-format off
737
+ inline __device__ Float8_ float_from_int8(int64_t u)
738
+ {
739
+ union {
740
+ int64_t int64;
741
+ int16_t int16[4];
742
+ };
743
+ int64 = u;
744
+ return Float8_ {float_from_int8(int16[0]),
745
+ float_from_int8(int16[1]),
746
+ float_from_int8(int16[2]),
747
+ float_from_int8(int16[3])};
748
+ }
749
+ // clang-format on
750
+
751
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
752
+
753
+ inline __device__ int8_t cast_to_int8(float val)
754
+ {
755
+ union {
756
+ int8_t int8[2];
757
+ int16_t int16;
758
+ };
759
+ asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
760
+ return int8[0];
761
+ }
762
+
763
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
764
+
765
+ inline __device__ int32_t cast_to_int8(float4 val)
766
+ {
767
+ union {
768
+ int8_t int8[4];
769
+ int32_t int32;
770
+ };
771
+ int8[0] = cast_to_int8(val.x);
772
+ int8[1] = cast_to_int8(val.y);
773
+ int8[2] = cast_to_int8(val.z);
774
+ int8[3] = cast_to_int8(val.w);
775
+ return int32;
776
+ }
777
+
778
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
779
+
780
+ inline __device__ int64_t cast_to_int8(Float8_ val)
781
+ {
782
+ union {
783
+ int8_t int8[8];
784
+ int64_t int64;
785
+ };
786
+ int8[0] = cast_to_int8(val.x.x);
787
+ int8[1] = cast_to_int8(val.x.y);
788
+ int8[2] = cast_to_int8(val.y.x);
789
+ int8[3] = cast_to_int8(val.y.y);
790
+ int8[4] = cast_to_int8(val.z.x);
791
+ int8[5] = cast_to_int8(val.z.y);
792
+ int8[6] = cast_to_int8(val.w.x);
793
+ int8[7] = cast_to_int8(val.w.y);
794
+ return int64;
795
+ }
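The PTX instruction cvt.rni.sat.s8.f32 used by cast_to_int8 rounds to the nearest integer (ties to even) and saturates to the int8 range. A plain C++ reference of the same semantics, for illustration only; it is not the code path the kernel uses:

// Host-side reference for cvt.rni.sat.s8.f32: round to nearest (ties to even
// under the default rounding mode) and clamp to [-128, 127].
#include <cmath>
#include <cstdint>
#include <cstdio>

static int8_t cast_to_int8_ref(float val)
{
    float r = std::nearbyintf(val);                   // round to nearest, ties to even
    r = std::fminf(127.0f, std::fmaxf(-128.0f, r));   // saturate to the int8 range
    return static_cast<int8_t>(r);
}

int main()
{
    printf("%d %d %d\n",
           cast_to_int8_ref(300.0f),    // 127  (saturated)
           cast_to_int8_ref(-1000.0f),  // -128 (saturated)
           cast_to_int8_ref(2.5f));     // 2    (tie rounded to even)
    return 0;
}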
796
+
797
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
798
+
799
+ template<typename T>
800
+ inline __device__ __host__ T div_up(T m, T n)
801
+ {
802
+ return (m + n - 1) / n;
803
+ }
804
+
805
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
806
+
807
+ template<typename T, bool DO_CROSS_ATTENTION>
808
+ inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
809
+ int threads_per_value,
810
+ int threads_per_block)
811
+ {
812
+ // The amount of shared memory needed to store the Q*K^T values in float.
813
+ const int max_timesteps = min(params.timestep, params.memory_max_len);
814
+ size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
815
+
816
+ // The extra memory needed if we are not using floats for the final logits.
817
+ size_t logits_sz = 0;
818
+ #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
819
+ if (sizeof(T) != 4) {
820
+ // TODO
821
+ logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
822
+ div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
823
+ }
824
+ #endif
825
+
826
+ // The total size needed during softmax.
827
+ size_t softmax_sz = qk_sz + logits_sz;
828
+
829
+ // The number of partial rows to reduce in the final reduction.
830
+ int rows_per_red = threads_per_block / threads_per_value;
831
+ // The amount of storage needed to finalize the outputs.
832
+ size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
833
+
834
+ size_t transpose_rotary_size = 0;
835
+ if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
836
+ transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
837
+ }
838
+
839
+ // The max.
840
+ return max(max(softmax_sz, red_sz), transpose_rotary_size);
841
+ }
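A worked sizing example may help; every concrete value below is an assumption chosen for illustration: fp16 activations, hidden_size_per_head = 128, 256 threads per block, 16 threads per value, self-attention with timestep = 1024 fitting inside memory_max_len, no NeoX-style rotary, and MMHA_USE_FP32_ACUM_FOR_LOGITS left undefined.

#include <cstdint>

// Worked example of the sizing above (all configuration values assumed).
constexpr size_t div_up_(size_t m, size_t n) { return (m + n - 1) / n; }

constexpr size_t qk_sz     = div_up_(1024 + 1, 4) * 16;                    // 4112 B of fp32 QK^T scores
constexpr size_t logits_sz = div_up_(1024 + 1, 4) * 4 * sizeof(uint16_t);  // 2056 B of fp16 logits
constexpr size_t red_sz    = (256 / 16) * 128 * sizeof(uint16_t) / 2;      // 2048 B for the output reduction
static_assert(qk_sz + logits_sz == 6168, "softmax scratch dominates: 6168 bytes of dynamic smem");
static_assert(red_sz == 2048, "the value-reduction scratch fits inside the softmax scratch");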
842
+
843
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
844
+
845
+ inline __device__ constexpr uint32_t shfl_mask(int threads)
846
+ {
847
+ return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
848
+ }
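shfl_mask builds the member mask for the partial-warp shuffle used in the Q*K section below: simply the lowest `threads` bits, with the threads == 32 case special-cased to avoid the undefined 1u << 32 shift. A host-side re-statement for checking the values, illustration only:

#include <cstdint>

// Reference re-statement of shfl_mask for host-side checking (illustration only).
constexpr uint32_t shfl_mask_ref(int threads)
{
    return threads == 32 ? 0xffffffffu : (1u << threads) - 1u;
}
static_assert(shfl_mask_ref(32) == 0xffffffffu, "full warp");
static_assert(shfl_mask_ref(8)  == 0x000000ffu, "lowest 8 lanes");
static_assert(shfl_mask_ref(1)  == 0x00000001u, "single lane");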
849
+
850
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
851
+
852
+ template<
853
+ // The type of the inputs. Supported types: float and half.
854
+ typename T,
855
+ // The hidden dimension per head.
856
+ int Dh,
857
+ int Dh_MAX,
858
+ // The number of threads per key.
859
+ int THREADS_PER_KEY,
860
+ // The number of threads per value.
861
+ int THREADS_PER_VALUE,
862
+ // The number of threads in a threadblock.
863
+ int THREADS_PER_BLOCK,
864
+ bool DO_CROSS_ATTENTION>
865
+ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
866
+ {
867
+
868
+ // Make sure the hidden dimension per head is a multiple of the number of threads per key.
869
+ static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
870
+ // Make sure the hidden dimension per head is a multiple of the number of threads per value.
871
+ static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
872
+
873
+ // The size of a warp.
874
+ constexpr int WARP_SIZE = 32;
875
+ // The number of warps in a threadblock.
876
+ constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
877
+
878
+ // Use smem_size_in_bytes (above) to determine the amount of shared memory.
879
+ extern __shared__ char smem_[];
880
+
881
+ // The shared memory for the Q*K^T values and partial logits in softmax.
882
+ float* qk_smem = reinterpret_cast<float*>(smem_);
883
+
884
+ // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
885
+ char* logits_smem_ = smem_;
886
+ #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
887
+ if (sizeof(T) != 4) {
888
+ // TODO - change to tlength
889
+ const int max_timesteps = min(params.timestep, params.memory_max_len);
890
+ logits_smem_ +=
891
+ (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
892
+ }
893
+ T* logits_smem = reinterpret_cast<T*>(logits_smem_);
894
+ #else
895
+ float* logits_smem = reinterpret_cast<float*>(logits_smem_);
896
+ #endif
897
+
898
+ // The shared memory to do the final reduction for the output values. Reuse qk_smem.
899
+ T* out_smem = reinterpret_cast<T*>(smem_);
900
+
901
+ // The shared memory buffers for the block-wide reductions. One for max, one for sum.
902
+ __shared__ float red_smem[WARPS_PER_BLOCK * 2];
903
+
904
+ // A vector of Q or K elements for the current timestep.
905
+ using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
906
+
907
+ // Use alignment for safely casting the shared buffers as Qk_vec.
908
+ // Shared memory to store Q inputs.
909
+ __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
910
+
911
+ // This is one of the reasons we should have a separate kernel for cross attention
912
+ __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
913
+
914
+ // A vector of Q or K elements for the current timestep.
915
+ using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
916
+ // The number of elements per vector.
917
+ constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
918
+ // Make sure the hidden size per head is a multiple of the vector size.
919
+ static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
920
+ // We will use block wide reduction if needed
921
+ // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
922
+ // The number of vectors per warp.
923
+ constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
924
+
925
+ // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
926
+ // owns x elements, we have to decompose the linear index into chunks of x values and the posi-
927
+ // tion of the thread in that chunk.
928
+
929
+ // The number of elements in a chunk of 16B (that's the x in the above formula).
930
+ constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
931
+ // The number of K vectors in 16B.
932
+ constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
933
+
934
+ // The batch/beam idx
935
+ const int bi = blockIdx.y;
936
+ if (params.finished != nullptr && params.finished[bi] == true) {
937
+ return;
938
+ }
939
+ // The beam idx
940
+ const int beami = bi % params.beam_width;
941
+ // The "beam-aware" batch idx
942
+ const int bbi = bi / params.beam_width;
943
+ // The head.
944
+ const int num_kv_heads = params.num_kv_heads;
945
+ const int kv_rep = (params.num_heads / num_kv_heads);
946
+ const int hi = blockIdx.x;
947
+ const int hi_kv = hi / kv_rep;
948
+
949
+ // Combine the batch and the head indices.
950
+ const int bhi = bi * params.num_heads + hi;
951
+ const int bhi_kv = bi * (params.num_heads / kv_rep) + hi_kv;
952
+ // Combine the "beam-aware" batch idx and the head indices.
953
+ const int bbhi = bbi * params.beam_width * params.num_heads + hi;
954
+ const int bbhi_kv = bbi * params.beam_width * (params.num_heads / kv_rep) + hi_kv;
955
+ // The thread in the block.
956
+ const int tidx = threadIdx.x;
957
+
958
+ const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
959
+ // Every kv_rep threads have the same kv_cache values. So only the first one writes back.
960
+ const int write_kv_cache = handle_kv && (hi % kv_rep == 0);
961
+
962
+ // While doing the product Q*K^T for the different keys we track the max.
963
+ float qk_max = -FLT_MAX;
964
+
965
+ float qk = 0.0F;
966
+
967
+ // int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
968
+ const int q_base_offset = bi * params.stride + hi * Dh;
969
+ const int k_base_offset = bi * params.stride + hi_kv * Dh;
970
+ const int v_base_offset = k_base_offset;
971
+
972
+ const size_t bi_seq_len_offset = bi * params.memory_max_len;
973
+
974
+ // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
975
+ int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
976
+ (params.length_per_sample == nullptr) ?
977
+ params.timestep :
978
+ params.length_per_sample[bi] + params.max_prefix_prompt_length;
979
+ const int first_step = max(0, tlength + 1 - params.memory_max_len);
980
+ const int tlength_circ = tlength % params.memory_max_len;
981
+
982
+ // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
983
+ const bool is_masked = tidx >= QK_VECS_PER_WARP;
984
+
985
+ // The offset in the Q and K buffer also accounts for the batch.
986
+ // int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
987
+ int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
988
+ int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
989
+ int v_offset = k_offset;
990
+
991
+ // The offset in the bias buffer.
992
+ // int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
993
+ int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
994
+ int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
995
+ int v_bias_offset = k_bias_offset;
996
+
997
+ const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
998
+ const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
999
+
1000
+ // Trigger the loads from the Q and K buffers.
1001
+ Qk_vec q;
1002
+ zero(q);
1003
+ if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
1004
+ if (params.int8_mode == 2) {
1005
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
1006
+ using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
1007
+ const auto q_scaling = params.qkv_scale_out[0];
1008
+ const auto q_quant =
1009
+ *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
1010
+
1011
+ convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
1012
+ }
1013
+ else {
1014
+ q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
1015
+ }
1016
+ }
1017
+
1018
+ Qk_vec k;
1019
+ zero(k);
1020
+ if (DO_CROSS_ATTENTION) {
1021
+ // The 16B chunk written by the thread.
1022
+ int co = tidx / QK_VECS_IN_16B;
1023
+ // The position of the thread in that 16B chunk.
1024
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
1025
+
1026
+ // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
1027
+ int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
1028
+ // params.timestep*QK_ELTS_IN_16B +
1029
+ tlength * QK_ELTS_IN_16B + ci;
1030
+ k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
1031
+ *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
1032
+ k;
1033
+ }
1034
+ else {
1035
+ if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
1036
+ if (params.int8_mode == 2) {
1037
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
1038
+ using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
1039
+ const auto k_scaling = params.qkv_scale_out[1];
1040
+ const auto k_quant =
1041
+ *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
1042
+
1043
+ convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
1044
+ }
1045
+ else {
1046
+ k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
1047
+ }
1048
+ }
1049
+ }
1050
+
1051
+ // Trigger the loads from the Q and K bias buffers.
1052
+ Qk_vec q_bias;
1053
+ zero(q_bias);
1054
+ q_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
1055
+ *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
1056
+ q_bias;
1057
+
1058
+ Qk_vec k_bias;
1059
+ zero(k_bias);
1060
+ if (handle_kv) {
1061
+ k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
1062
+ *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
1063
+ k_bias;
1064
+ }
1065
+
1066
+ // Computes the Q/K values with bias.
1067
+ q = add(q, q_bias);
1068
+ if (handle_kv) {
1069
+ k = add(k, k_bias);
1070
+ }
1071
+ if (do_ia3 && !is_masked) {
1072
+ k = mul<Qk_vec, Qk_vec, Qk_vec>(
1073
+ k,
1074
+ *reinterpret_cast<const Qk_vec*>(
1075
+ &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
1076
+ }
1077
+
1078
+ // Padded len
1079
+ const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
1080
+ if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
1081
+ if (handle_kv) {
1082
+ apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
1083
+ }
1084
+ else {
1085
+ apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
1086
+ }
1087
+ }
1088
+ else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
1089
+ const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
1090
+
1091
+ T* q_smem = reinterpret_cast<T*>(smem_);
1092
+ T* k_smem = q_smem + params.rotary_embedding_dim;
1093
+
1094
+ const int half_rotary_dim = params.rotary_embedding_dim / 2;
1095
+ const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
1096
+ const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
1097
+ const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
1098
+
1099
+ assert(half_rotary_dim % QK_VEC_SIZE == 0);
1100
+
1101
+ if (do_rotary) {
1102
+ *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
1103
+
1104
+ if (handle_kv) {
1105
+ *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
1106
+ }
1107
+ }
1108
+
1109
+ __syncthreads();
1110
+
1111
+ const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
1112
+ constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
1113
+ if (do_rotary) {
1114
+ mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
1115
+
1116
+ if (handle_kv) {
1117
+ mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
1118
+
1119
+ mmha::apply_rotary_embedding(
1120
+ q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
1121
+
1122
+ mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
1123
+ }
1124
+ else {
1125
+ mmha::apply_rotary_embedding(
1126
+ q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
1127
+ }
1128
+ mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
1129
+ }
1130
+
1131
+ __syncthreads();
1132
+
1133
+ if (do_rotary) {
1134
+ q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
1135
+ if (handle_kv) {
1136
+ k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
1137
+ }
1138
+ }
1139
+
1140
+ __syncthreads();
1141
+ }
1142
+
1143
+ if (!is_masked) {
1144
+ // Store the Q values to shared memory.
1145
+ *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
1146
+
1147
+ // Store Dh values of k_bias into smem, since we will need to add them later
1148
+ // if params.timestep == 0
1149
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1150
+ *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
1151
+ }
1152
+
1153
+ // Write the K values to the global memory cache.
1154
+ //
1155
+ // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
1156
+ // system. We designed it this way as it allows much better memory loads (and there are many
1157
+ // more loads) + the stores are really "write and forget" since we won't need the ack before
1158
+ // the end of the kernel. There's plenty of time for the transactions to complete.
1159
+
1160
+ // The 16B chunk written by the thread.
1161
+ int co = tidx / QK_VECS_IN_16B;
1162
+ // The position of the thread in that 16B chunk.
1163
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
1164
+
1165
+ // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
1166
+ int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
1167
+ // params.timestep*QK_ELTS_IN_16B +
1168
+ tlength_circ * QK_ELTS_IN_16B + ci;
1169
+
1170
+ if (write_kv_cache) {
1171
+ // Trigger the stores to global memory.
1172
+ if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
1173
+ *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
1174
+ }
1175
+ }
1176
+
1177
+ // Compute \sum_i Q[i] * K^T[i] for the current timestep.
1178
+ #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
1179
+ using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
1180
+ #else
1181
+ using Qk_vec_acum = Qk_vec;
1182
+ #endif
1183
+ qk = dot<Qk_vec_acum, Qk_vec>(q, k);
1184
+ if (QK_VECS_PER_WARP <= WARP_SIZE) {
1185
+ #pragma unroll
1186
+ for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
1187
+ qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
1188
+ }
1189
+ }
1190
+ }
1191
+
1192
+ if (QK_VECS_PER_WARP > WARP_SIZE) {
1193
+ constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
1194
+ qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
1195
+ }
1196
+
1197
+ // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
1198
+ if (tidx == 0) {
1199
+ // Normalize qk.
1200
+ qk *= params.inv_sqrt_dh;
1201
+ if (params.relative_attention_bias != nullptr) {
1202
+ // TODO (Haotian): check whether we should replace hi with hi_kv,
1203
+ // although params.relative_attention_bias is usually not used.
1204
+ qk = add(qk,
1205
+ params.relative_attention_bias[hi * params.relative_attention_bias_stride
1206
+ * params.relative_attention_bias_stride
1207
+ + (tlength - padd_len) * params.relative_attention_bias_stride
1208
+ + (tlength - padd_len)]);
1209
+ }
1210
+ // Add alibi positional encoding
1211
+ // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
1212
+ // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
1213
+
1214
+ qk_max = qk;
1215
+ qk_smem[tlength - first_step] = qk;
1216
+ // qk_smem[params.timestep] = qk;
1217
+ }
1218
+
1219
+ // Make sure the data is in shared memory.
1220
+ __syncthreads();
1221
+
1222
+ // The type of queries and keys for the math in the Q*K^T product.
1223
+ using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
1224
+ // The number of elements per vector.
1225
+ constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
1226
+ // Make sure the hidden size per head is a multiple of the vector size.
1227
+ static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
1228
+ // The number of elements per thread.
1229
+ constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
1230
+ // The number of vectors per thread.
1231
+ constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
1232
+
1233
+ // The position the first key loaded by each thread from the cache buffer (for this B * H).
1234
+ int ko = tidx / THREADS_PER_KEY;
1235
+ // The position of the thread in the chunk of keys.
1236
+ int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
1237
+
1238
+ static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
1239
+
1240
+ // Load the Q values from shared memory. The values are reused during the loop on K.
1241
+ K_vec q_vec[K_VECS_PER_THREAD];
1242
+ #pragma unroll
1243
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
1244
+ q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
1245
+ }
1246
+
1247
+ K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
1248
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1249
+ #pragma unroll
1250
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
1251
+ k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
1252
+ }
1253
+ }
1254
+
1255
+ // The number of timesteps loaded per iteration.
1256
+ constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
1257
+ // The number of keys per warp.
1258
+ constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
1259
+
1260
+ // The base pointer for the key in the cache buffer.
1261
+ T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
1262
+ // Base pointer for the beam's batch, before offsetting with indirection buffer
1263
+ T* k_cache_batch = &params.k_cache[bbhi_kv * params.memory_max_len * Dh + ki];
1264
+
1265
+ // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
1266
+ // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
1267
+ int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
1268
+
1269
+ // prefix prompt length if has
1270
+ const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
1271
+
1272
+ // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
1273
+ const bool has_beams = params.cache_indir != nullptr;
1274
+ const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
1275
+
1276
+ for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
1277
+ const int ti_circ = ti % params.memory_max_len;
1278
+
1279
+ // The keys loaded from the key cache.
1280
+ K_vec k[K_VECS_PER_THREAD];
1281
+ K_vec k_vec_zero;
1282
+ zero(k_vec_zero);
1283
+ #pragma unroll
1284
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
1285
+ int jj = ii * params.memory_max_len + ti_circ;
1286
+ // if( ti < params.timestep ) {
1287
+ const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
1288
+ if (ti < tlength) {
1289
+ if (!within_bounds) {
1290
+ k[ii] = k_vec_zero;
1291
+ }
1292
+ else {
1293
+ if (has_beams) {
1294
+ const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
1295
+ k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
1296
+ }
1297
+ else {
1298
+ k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
1299
+ }
1300
+ }
1301
+ // add bias and update k_cache
1302
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1303
+ k[ii] = add(k[ii], k_bias_vec[ii]);
1304
+
1305
+ if (do_ia3) {
1306
+ k[ii] = mul<K_vec, K_vec, K_vec>(
1307
+ k[ii],
1308
+ *reinterpret_cast<const K_vec*>(
1309
+ &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
1310
+ + ii * THREADS_PER_KEY * K_VEC_SIZE]));
1311
+ }
1312
+
1313
+ if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
1314
+ *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
1315
+ }
1316
+ }
1317
+ }
1318
+ }
1319
+
1320
+ // Perform the dot product and normalize qk.
1321
+ //
1322
+ // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
1323
+ float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
1324
+ bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
1325
+
1326
+ // Store the product to shared memory. There's one qk value per timestep. Update the max.
1327
+ // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
1328
+ if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
1329
+ if (params.relative_attention_bias != nullptr) {
1330
+ qk = add(qk,
1331
+ params.relative_attention_bias[hi * params.relative_attention_bias_stride
1332
+ * params.relative_attention_bias_stride
1333
+ + tlength * params.relative_attention_bias_stride + ti]);
1334
+ }
1335
+ if (params.linear_bias_slopes != nullptr) {
1336
+ // Apply the linear position bias: (ki - qi) * slope[hi].
1337
+ // The padding token locates between the input context and the generated tokens.
1338
+ // We need to remove the number of padding tokens in the distance computation.
1339
+ // ti : 0 1 2 3 4 5 6 7 8 9(tlength)
1340
+ // token: i i i i p p p o o o where i=input, p=pad, o=output.
1341
+ // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
1342
+ int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
1343
+ float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
1344
+
1345
+ qk += mul<float, float, float>(params.linear_bias_slopes[hi], dist);
1346
+ }
1347
+ // Add alibi positional encoding
1348
+ // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
1349
+ qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
1350
+ qk_smem[ti - first_step] = qk;
1351
+ }
1352
+ }
1353
+
1354
+ // Perform the final reduction to compute the max inside each warp.
1355
+ //
1356
+ // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
1357
+ // group so it's not needed to run the reduction inside the group (again).
1358
+ #pragma unroll
1359
+ for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
1360
+ qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
1361
+ }
1362
+
1363
+ // Decompose the thread index into warp and lane.
1364
+ const int warp = tidx / WARP_SIZE;
1365
+ const int lane = tidx % WARP_SIZE;
1366
+
1367
+ // The warp leader writes the max to shared memory.
1368
+ if (lane == 0) {
1369
+ red_smem[warp] = qk_max;
1370
+ }
1371
+
1372
+ // Make sure the products are in shared memory.
1373
+ __syncthreads();
1374
+
1375
+ // The warps finalize the reduction.
1376
+ qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
1377
+ #pragma unroll
1378
+ for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
1379
+ qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
1380
+ }
1381
+
1382
+ // Broadcast to all the threads in the warp.
1383
+ qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
1384
+
1385
+ // Compute the logits and start the sum.
1386
+ float sum = 0.f;
1387
+ // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
1388
+ for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
1389
+ bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
1390
+ float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
1391
+ sum += logit;
1392
+ qk_smem[ti - first_step] = logit;
1393
+ }
1394
+
1395
+ // Compute the sum.
1396
+ sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
1397
+
1398
+ // Normalize the logits.
1399
+ float inv_sum = __fdividef(1.f, sum + 1.e-6f);
1400
+ // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
1401
+ const size_t cross_attention_out_offset =
1402
+ params.is_return_cross_attentions ?
1403
+ bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
1404
+ 0;
1405
+ for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
1406
+ float logit = qk_smem[ti - first_step] * inv_sum;
1407
+ if (params.is_return_cross_attentions) {
1408
+ params.cross_attention_out[cross_attention_out_offset + ti] = logit;
1409
+ }
1410
+ convert_from_float(logits_smem[ti - first_step], logit);
1411
+ }
1412
+
1413
+ // Put the values part below so we can leverage the __syncthreads
1414
+ // from the previous step
1415
+
1416
+ // The number of elements per vector.
1417
+ constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
1418
+ // A vector of V elements for the current timestep.
1419
+ using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
1420
+
1421
+ // The value computed by this thread.
1422
+ int vo = tidx / THREADS_PER_VALUE;
1423
+ // The hidden dimensions computed by this particular thread.
1424
+ int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
1425
+
1426
+ // The base pointer for the value in the cache buffer.
1427
+ T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
1428
+ // Base pointer for the beam's batch, before offsetting with indirection buffer
1429
+ T* v_cache_batch = &params.v_cache[bbhi_kv * params.memory_max_len * Dh + vi];
1430
+
1431
+ // The number of values processed per iteration of the loop.
1432
+ constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
1433
+
1434
+ // One group of threads computes the product(s) for the current timestep.
1435
+ V_vec v_bias;
1436
+ zero(v_bias);
1437
+ // if( vo == params.timestep % V_PER_ITER ) {
1438
+ if (Dh == Dh_MAX || vi < Dh) {
1439
+ if (handle_kv) {
1440
+ if (vo == tlength % V_PER_ITER) {
1441
+ // Trigger the loads from the V bias buffer.
1442
+ if (params.v_bias != nullptr) {
1443
+ v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
1444
+ }
1445
+ if (DO_CROSS_ATTENTION) {
1446
+ *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
1447
+ }
1448
+ }
1449
+ }
1450
+ }
1451
+
1452
+ // Carried over from the previous step (before the values part).
1453
+ // Also make sure the logits are in shared memory.
1454
+ __syncthreads();
1455
+
1456
+ // Values continued
1457
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
1458
+ using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
1459
+ #else
1460
+ using V_vec_acum = V_vec;
1461
+ #endif
1462
+ // The partial outputs computed by each thread.
1463
+ V_vec_acum out;
1464
+ zero(out);
1465
+
1466
+ // Loop over the timesteps to compute the partial outputs.
1467
+ // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
1468
+ if (Dh == Dh_MAX || vi < Dh) {
1469
+ for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
1470
+ const int ti_circ = ti % params.memory_max_len;
1471
+
1472
+ // Fetch offset based on cache_indir when beam sampling
1473
+ const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
1474
+ const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
1475
+ // Load the values from the cache.
1476
+ V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
1477
+ if (DO_CROSS_ATTENTION && params.timestep == 0) {
1478
+ v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
1479
+ if (do_ia3) {
1480
+ v = mul<V_vec, V_vec, V_vec>(
1481
+ v,
1482
+ *reinterpret_cast<const V_vec*>(
1483
+ &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
1484
+ }
1485
+ *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
1486
+ }
1487
+ // Load the logits from shared memory.
1488
+ #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
1489
+ float logit = logits_smem[ti - first_step];
1490
+ out = fma(logit, cast_to_float(v), out);
1491
+ #else
1492
+ T logit = logits_smem[ti - first_step];
1493
+
1494
+ // Update the partial sums.
1495
+ out = fma(logit, v, out);
1496
+ #endif
1497
+ }
1498
+ }
1499
+
1500
+ // One group of threads computes the product(s) for the current timestep.
1501
+ // if( vo == params.timestep % V_PER_ITER ) {
1502
+ if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
1503
+
1504
+ V_vec v;
1505
+ if (DO_CROSS_ATTENTION) {
1506
+ v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
1507
+ }
1508
+ else {
1509
+ // Trigger the loads from the V buffer.
1510
+ const auto v_offset = v_base_offset + vi;
1511
+ if (params.int8_mode == 2) {
1512
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
1513
+ using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
1514
+ const auto v_scaling = params.qkv_scale_out[2];
1515
+ const auto v_quant =
1516
+ *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
1517
+
1518
+ convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
1519
+ }
1520
+ else {
1521
+ v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
1522
+ }
1523
+ // Trigger the loads from the V bias buffer.
1524
+ // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
1525
+ }
1526
+
1527
+ // Compute the V values with bias.
1528
+ v = add(v, v_bias);
1529
+ if (write_kv_cache) {
1530
+
1531
+ if (do_ia3) {
1532
+ v = mul<V_vec, V_vec, V_vec>(
1533
+ v,
1534
+ *reinterpret_cast<const V_vec*>(
1535
+ &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
1536
+ }
1537
+
1538
+ // Store the values with bias back to global memory in the cache for V.
1539
+ //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
1540
+ *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
1541
+ }
1542
+
1543
+ // Initialize the output value with the current timestep.
1544
+ #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
1545
+ // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
1546
+ out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
1547
+ #else
1548
+ // out = fma(logits_smem[params.timestep], v, out);
1549
+ out = fma(logits_smem[tlength - first_step], v, out);
1550
+ #endif
1551
+ }
1552
+
1553
+ // Make sure we can start writing to shared memory.
1554
+ __syncthreads();
1555
+
1556
+ // Run the final reduction amongst the different groups computing different partial outputs.
1557
+ if (Dh == Dh_MAX || vi < Dh) {
1558
+ #pragma unroll
1559
+ for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
1560
+
1561
+ // The midpoint in the number of active groups.
1562
+ int midpoint = active_groups / 2;
1563
+
1564
+ // The upper part of active threads store to shared memory.
1565
+ if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
1566
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
1567
+ convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
1568
+ #else
1569
+ *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
1570
+ #endif
1571
+ }
1572
+ __syncthreads();
1573
+
1574
+ // The bottom warps update their values.
1575
+ if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
1576
+ out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
1577
+ }
1578
+ __syncthreads();
1579
+ }
1580
+ }
1581
+
1582
+ // Output the final values.
1583
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
1584
+ #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
1585
+ if (params.int8_mode == 2) {
1586
+ using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
1587
+ out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
1588
+ *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
1589
+ cast_to_int8(out);
1590
+ }
1591
+ else {
1592
+ convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
1593
+ }
1594
+ #else
1595
+ // TODO: support int8_mode?
1596
+ *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
1597
+ #endif
1598
+ }
1599
+ }
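For orientation, a hedged sketch of how a launcher might instantiate this kernel. The repository's real launcher (mmha_launch_kernel, declared below) chooses the template parameters; the constants, the helper name, and the batch/beam handling here are assumptions for illustration only.

// Hypothetical launch sketch; not the launcher shipped in this repo.
template<typename T, int Dh, int Dh_MAX, bool DO_CROSS_ATTENTION>
void launch_example(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
                    int num_heads, int batch_x_beam, cudaStream_t stream)
{
    constexpr int THREADS_PER_BLOCK = 256;
    constexpr int THREADS_PER_KEY   = 4;                        // 16 B of key per group of 4 threads
    constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;  // 16 B of value per thread

    const size_t smem = smem_size_in_bytes<T, DO_CROSS_ATTENTION>(
        params, THREADS_PER_VALUE, THREADS_PER_BLOCK);
    // For long sequences the dynamic smem can exceed the default 48 KB limit, in
    // which case cudaFuncAttributeMaxDynamicSharedMemorySize must be raised first.

    // One block per (head, batch * beam) pair: blockIdx.x = head, blockIdx.y = batch * beam.
    dim3 grid(num_heads, batch_x_beam);
    masked_multihead_attention_kernel<T, Dh, Dh_MAX, THREADS_PER_KEY, THREADS_PER_VALUE,
                                      THREADS_PER_BLOCK, DO_CROSS_ATTENTION>
        <<<grid, THREADS_PER_BLOCK, smem, stream>>>(params);
}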
1600
+
1601
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1602
+
1603
+ } // namespace mmha
1604
+
1605
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1606
+
1607
+ template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
1608
+ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_utils.h ADDED
@@ -0,0 +1,1786 @@
1
+ // Downloaded from FasterTransformer v5.2.1
2
+ // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
3
+ /*
4
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #pragma once
20
+
21
+ #include "cuda_bf16_wrapper.h"
22
+ #include "cuda_bf16_fallbacks.cuh"
23
+ #include <stdint.h>
24
+
25
+ using namespace fastertransformer;
26
+
27
+ namespace mmha {
28
+
29
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
30
+
31
+ struct Float8_ {
32
+ float2 x;
33
+ float2 y;
34
+ float2 z;
35
+ float2 w;
36
+ };
37
+
38
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
39
+
40
+ struct Float4_ {
41
+ float2 x;
42
+ float2 y;
43
+ };
44
+
45
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ #ifdef ENABLE_BF16
48
+ struct bf16_4_t {
49
+ __nv_bfloat162 x;
50
+ __nv_bfloat162 y;
51
+ };
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ struct bf16_8_t {
56
+ __nv_bfloat162 x;
57
+ __nv_bfloat162 y;
58
+ __nv_bfloat162 z;
59
+ __nv_bfloat162 w;
60
+ };
61
+ #endif
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ template<typename T>
66
+ struct num_elems;
67
+ template<>
68
+ struct num_elems<float> {
69
+ static constexpr int value = 1;
70
+ };
71
+ template<>
72
+ struct num_elems<float2> {
73
+ static constexpr int value = 2;
74
+ };
75
+ template<>
76
+ struct num_elems<float4> {
77
+ static constexpr int value = 4;
78
+ };
79
+ template<>
80
+ struct num_elems<Float4_> {
81
+ static constexpr int value = 4;
82
+ };
83
+ template<>
84
+ struct num_elems<Float8_> {
85
+ static constexpr int value = 8;
86
+ };
87
+
88
+ template<>
89
+ struct num_elems<uint32_t> {
90
+ static constexpr int value = 2;
91
+ };
92
+ template<>
93
+ struct num_elems<uint2> {
94
+ static constexpr int value = 4;
95
+ };
96
+ template<>
97
+ struct num_elems<uint4> {
98
+ static constexpr int value = 8;
99
+ };
100
+
101
+ #ifdef ENABLE_BF16
102
+ template<>
103
+ struct num_elems<__nv_bfloat162> {
104
+ static constexpr int value = 2;
105
+ };
106
+ template<>
107
+ struct num_elems<bf16_4_t> {
108
+ static constexpr int value = 4;
109
+ };
110
+ template<>
111
+ struct num_elems<bf16_8_t> {
112
+ static constexpr int value = 8;
113
+ };
114
+ #endif
115
+
116
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
117
+
118
+ template<typename T, int N>
119
+ struct packed_type;
120
+ template<typename T>
121
+ struct packed_type<T, 1> {
122
+ using type = T;
123
+ };
124
+ template<>
125
+ struct packed_type<int8_t, 2> {
126
+ using type = int16_t;
127
+ };
128
+ template<>
129
+ struct packed_type<int8_t, 4> {
130
+ using type = int32_t;
131
+ };
132
+ template<>
133
+ struct packed_type<int8_t, 8> {
134
+ using type = int64_t;
135
+ };
136
+
137
+ template<>
138
+ struct packed_type<float, 2> {
139
+ using type = float2;
140
+ };
141
+ template<>
142
+ struct packed_type<float, 4> {
143
+ using type = float4;
144
+ };
145
+ template<>
146
+ struct packed_type<float, 8> {
147
+ using type = Float8_;
148
+ };
149
+
150
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
151
+
152
+ inline __device__ float add(float a, float b)
153
+ {
154
+ return a + b;
155
+ }
156
+
157
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
158
+
159
+ inline __device__ float2 add(float2 a, float2 b)
160
+ {
161
+ float2 c;
162
+ c.x = add(a.x, b.x);
163
+ c.y = add(a.y, b.y);
164
+ return c;
165
+ }
166
+
167
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
168
+
169
+ inline __device__ float4 add(float4 a, float4 b)
170
+ {
171
+ float4 c;
172
+ c.x = add(a.x, b.x);
173
+ c.y = add(a.y, b.y);
174
+ c.z = add(a.z, b.z);
175
+ c.w = add(a.w, b.w);
176
+ return c;
177
+ }
178
+
179
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
180
+
181
+ #ifdef ENABLE_BF16
182
+ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
183
+ {
184
+ return a + b;
185
+ }
186
+
187
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
188
+
189
+ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
190
+ {
191
+ return bf16hadd2(a, b);
192
+ }
193
+
194
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
195
+
196
+ inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
197
+ {
198
+ bf16_4_t c;
199
+ c.x = add(a.x, b.x);
200
+ c.y = add(a.y, b.y);
201
+ return c;
202
+ }
203
+
204
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
205
+
206
+ inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
207
+ {
208
+ bf16_8_t c;
209
+ c.x = add(a.x, b.x);
210
+ c.y = add(a.y, b.y);
211
+ c.z = add(a.z, b.z);
212
+ c.w = add(a.w, b.w);
213
+ return c;
214
+ }
215
+ #endif // ENABLE_BF16
216
+
217
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
218
+
219
+ inline __device__ uint16_t add(uint16_t a, uint16_t b)
220
+ {
221
+ uint16_t c;
222
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
223
+ return c;
224
+ }
225
+
226
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
227
+
228
+ inline __device__ uint32_t add(uint32_t a, uint32_t b)
229
+ {
230
+ uint32_t c;
231
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
232
+ return c;
233
+ }
234
+
235
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
236
+
237
+ inline __device__ uint2 add(uint2 a, uint2 b)
238
+ {
239
+ uint2 c;
240
+ c.x = add(a.x, b.x);
241
+ c.y = add(a.y, b.y);
242
+ return c;
243
+ }
244
+
245
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
246
+
247
+ inline __device__ uint4 add(uint4 a, uint4 b)
248
+ {
249
+ uint4 c;
250
+ c.x = add(a.x, b.x);
251
+ c.y = add(a.y, b.y);
252
+ c.z = add(a.z, b.z);
253
+ c.w = add(a.w, b.w);
254
+ return c;
255
+ }
256
+
257
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
258
+
259
+ inline __device__ uint16_t float_to_half(float f)
260
+ {
261
+ union {
262
+ uint32_t u32;
263
+ uint16_t u16[2];
264
+ } tmp;
265
+ #if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
266
+ float zero = 0.f;
267
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
268
+ #else
269
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
270
+ #endif
271
+ return tmp.u16[0];
272
+ }
273
+
274
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
275
+
276
+ inline __device__ uint32_t float2_to_half2(float2 f)
277
+ {
278
+ union {
279
+ uint32_t u32;
280
+ uint16_t u16[2];
281
+ } tmp;
282
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
283
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
284
+ #else
285
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
286
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
287
+ #endif
288
+ return tmp.u32;
289
+ }
290
+
291
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
292
+
293
+ inline __device__ float half_to_float(uint16_t h)
294
+ {
295
+ float f;
296
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
297
+ return f;
298
+ }
299
+
300
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
301
+
302
+ inline __device__ float2 half2_to_float2(uint32_t v)
303
+ {
304
+ uint16_t lo, hi;
305
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
306
+ return make_float2(half_to_float(lo), half_to_float(hi));
307
+ }
308
+
309
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
310
+
311
+ inline __device__ float add(float a, uint16_t b)
312
+ {
313
+ return a + half_to_float(b);
314
+ }
315
+
316
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
317
+
318
+ #ifdef ENABLE_BF16
319
+ inline __device__ float add(float a, __nv_bfloat16 b)
320
+ {
321
+ return a + __bfloat162float(b);
322
+ }
323
+ #endif
324
+
325
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
326
+
327
+ inline __device__ float2 add(uint32_t a, float2 fb)
328
+ {
329
+ float2 fa = half2_to_float2(a);
330
+ return add(fa, fb);
331
+ }
332
+
333
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
334
+
335
+ inline __device__ Float4_ add(uint2 a, Float4_ fb)
336
+ {
337
+ Float4_ fc;
338
+ fc.x = add(a.x, fb.x);
339
+ fc.y = add(a.y, fb.y);
340
+ return fc;
341
+ }
342
+
343
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
344
+
345
+ inline __device__ Float8_ add(uint4 a, Float8_ fb)
346
+ {
347
+ Float8_ fc;
348
+ fc.x = add(a.x, fb.x);
349
+ fc.y = add(a.y, fb.y);
350
+ fc.z = add(a.z, fb.z);
351
+ fc.w = add(a.w, fb.w);
352
+ return fc;
353
+ }
354
+
355
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
356
+
357
+ inline __device__ uint32_t h0_h0(uint16_t a)
358
+ {
359
+ uint32_t b;
360
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
361
+ return b;
362
+ }
363
+
364
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
365
+
366
+ inline __device__ float fma(float a, float b, float c)
367
+ {
368
+ return a * b + c;
369
+ }
370
+
371
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
372
+
373
+ inline __device__ float2 fma(float2 a, float2 b, float2 c)
374
+ {
375
+ float2 d;
376
+ d.x = fma(a.x, b.x, c.x);
377
+ d.y = fma(a.y, b.y, c.y);
378
+ return d;
379
+ }
380
+
381
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
382
+
383
+ inline __device__ float2 fma(float a, float2 b, float2 c)
384
+ {
385
+ float2 d;
386
+ d.x = fma(a, b.x, c.x);
387
+ d.y = fma(a, b.y, c.y);
388
+ return d;
389
+ }
390
+
391
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
392
+
393
+ inline __device__ float4 fma(float4 a, float4 b, float4 c)
394
+ {
395
+ float4 d;
396
+ d.x = fma(a.x, b.x, c.x);
397
+ d.y = fma(a.y, b.y, c.y);
398
+ d.z = fma(a.z, b.z, c.z);
399
+ d.w = fma(a.w, b.w, c.w);
400
+ return d;
401
+ }
402
+
403
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
404
+
405
+ inline __device__ float4 fma(float a, float4 b, float4 c)
406
+ {
407
+ float4 d;
408
+ d.x = fma(a, b.x, c.x);
409
+ d.y = fma(a, b.y, c.y);
410
+ d.z = fma(a, b.z, c.z);
411
+ d.w = fma(a, b.w, c.w);
412
+ return d;
413
+ }
414
+
415
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
416
+
417
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
418
+ {
419
+ Float4_ d;
420
+ d.x = fma(a, b.x, c.x);
421
+ d.y = fma(a, b.y, c.y);
422
+ return d;
423
+ }
424
+
425
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
426
+
427
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
428
+ {
429
+ Float8_ d;
430
+ d.x = fma(a, b.x, c.x);
431
+ d.y = fma(a, b.y, c.y);
432
+ d.z = fma(a, b.z, c.z);
433
+ d.w = fma(a, b.w, c.w);
434
+ return d;
435
+ }
436
+
437
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
438
+
439
+ #ifdef ENABLE_BF16
440
+ inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
441
+ {
442
+ float2 fa = bf1622float2(a);
443
+ return add(fa, fb);
444
+ }
445
+
446
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
447
+
448
+ inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
449
+ {
450
+ Float4_ fc;
451
+ fc.x = add(a.x, fb.x);
452
+ fc.y = add(a.y, fb.y);
453
+ return fc;
454
+ }
455
+
456
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
457
+
458
+ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
459
+ {
460
+ Float8_ fc;
461
+ fc.x = add(a.x, fb.x);
462
+ fc.y = add(a.y, fb.y);
463
+ fc.z = add(a.z, fb.z);
464
+ fc.w = add(a.w, fb.w);
465
+ return fc;
466
+ }
467
+ #endif // ENABLE_BF16
468
+
469
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
470
+
471
+ inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
472
+ {
473
+ uint32_t d;
474
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
475
+ return d;
476
+ }
477
+
478
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
479
+
480
+ inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
481
+ {
482
+ return fma(h0_h0(a), b, c);
483
+ }
484
+
485
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
486
+
487
+ inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
488
+ {
489
+ uint2 d;
490
+ d.x = fma(a.x, b.x, c.x);
491
+ d.y = fma(a.y, b.y, c.y);
492
+ return d;
493
+ }
494
+
495
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
496
+
497
+ inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
498
+ {
499
+ uint32_t s = h0_h0(a);
500
+ uint2 d;
501
+ d.x = fma(s, b.x, c.x);
502
+ d.y = fma(s, b.y, c.y);
503
+ return d;
504
+ }
505
+
506
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
507
+
508
+ inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
509
+ {
510
+ uint4 d;
511
+ d.x = fma(a.x, b.x, c.x);
512
+ d.y = fma(a.y, b.y, c.y);
513
+ d.z = fma(a.z, b.z, c.z);
514
+ d.w = fma(a.w, b.w, c.w);
515
+ return d;
516
+ }
517
+
518
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
519
+
520
+ inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
521
+ {
522
+ uint32_t s = h0_h0(a);
523
+ uint4 d;
524
+ d.x = fma(s, b.x, c.x);
525
+ d.y = fma(s, b.y, c.y);
526
+ d.z = fma(s, b.z, c.z);
527
+ d.w = fma(s, b.w, c.w);
528
+ return d;
529
+ }
530
+
531
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
532
+
533
+ inline __device__ float fma(uint16_t a, uint16_t b, float fc)
534
+ {
535
+ float fa = half_to_float(a);
536
+ float fb = half_to_float(b);
537
+ return fa * fb + fc;
538
+ }
539
+
540
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
541
+
542
+ inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
543
+ {
544
+ float2 fa = half2_to_float2(a);
545
+ float2 fb = half2_to_float2(b);
546
+ return fma(fa, fb, fc);
547
+ }
548
+
549
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
550
+
551
+ inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
552
+ {
553
+ return fma(h0_h0(a), b, fc);
554
+ }
555
+
556
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
557
+
558
+ inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
559
+ {
560
+ Float4_ fd;
561
+ fd.x = fma(a.x, b.x, fc.x);
562
+ fd.y = fma(a.y, b.y, fc.y);
563
+ return fd;
564
+ }
565
+
566
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
567
+
568
+ inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
569
+ {
570
+ uint32_t s = h0_h0(a);
571
+ Float4_ fd;
572
+ fd.x = fma(s, b.x, fc.x);
573
+ fd.y = fma(s, b.y, fc.y);
574
+ return fd;
575
+ }
576
+
577
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
578
+
579
+ inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
580
+ {
581
+ Float8_ fd;
582
+ fd.x = fma(a.x, b.x, fc.x);
583
+ fd.y = fma(a.y, b.y, fc.y);
584
+ fd.z = fma(a.z, b.z, fc.z);
585
+ fd.w = fma(a.w, b.w, fc.w);
586
+ return fd;
587
+ }
588
+
589
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
590
+
591
+ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
592
+ {
593
+ uint32_t s = h0_h0(a);
594
+ Float8_ fd;
595
+ fd.x = fma(s, b.x, fc.x);
596
+ fd.y = fma(s, b.y, fc.y);
597
+ fd.z = fma(s, b.z, fc.z);
598
+ fd.w = fma(s, b.w, fc.w);
599
+ return fd;
600
+ }
601
+
602
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
603
+ #ifdef ENABLE_BF16
604
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
605
+ {
606
+ return bf16hfma2(a, b, c);
607
+ }
608
+
609
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
610
+
611
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
612
+ {
613
+ return bf16hfma2(bf162bf162(a), b, c);
614
+ }
615
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
616
+
617
+ inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
618
+ {
619
+ bf16_4_t d;
620
+ d.x = fma(a.x, b.x, c.x);
621
+ d.y = fma(a.y, b.y, c.y);
622
+ return d;
623
+ }
624
+
625
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
626
+
627
+ inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
628
+ {
629
+ __nv_bfloat162 s = bf162bf162(a);
630
+ bf16_4_t d;
631
+ d.x = fma(s, b.x, c.x);
632
+ d.y = fma(s, b.y, c.y);
633
+ return d;
634
+ }
635
+
636
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
637
+
638
+ inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
639
+ {
640
+ bf16_8_t d;
641
+ d.x = fma(a.x, b.x, c.x);
642
+ d.y = fma(a.y, b.y, c.y);
643
+ d.z = fma(a.z, b.z, c.z);
644
+ d.w = fma(a.w, b.w, c.w);
645
+ return d;
646
+ }
647
+
648
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
649
+
650
+ inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
651
+ {
652
+ __nv_bfloat162 s = bf162bf162(a);
653
+ bf16_8_t d;
654
+ d.x = fma(s, b.x, c.x);
655
+ d.y = fma(s, b.y, c.y);
656
+ d.z = fma(s, b.z, c.z);
657
+ d.w = fma(s, b.w, c.w);
658
+ return d;
659
+ }
660
+
661
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
662
+
663
+ inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
664
+ {
665
+ return __bfloat162float(a) * __bfloat162float(b) + fc;
666
+ }
667
+
668
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
669
+
670
+ inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
671
+ {
672
+ float2 fa = bf1622float2(a);
673
+ float2 fb = bf1622float2(b);
674
+ return fma(fa, fb, fc);
675
+ }
676
+
677
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
678
+
679
+ inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
680
+ {
681
+ return fma(bf162bf162(a), b, fc);
682
+ }
683
+
684
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
685
+
686
+ inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
687
+ {
688
+ Float4_ fd;
689
+ fd.x = fma(a.x, b.x, fc.x);
690
+ fd.y = fma(a.y, b.y, fc.y);
691
+ return fd;
692
+ }
693
+
694
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
695
+
696
+ inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
697
+ {
698
+ __nv_bfloat162 s = bf162bf162(a);
699
+ Float4_ fd;
700
+ fd.x = fma(s, b.x, fc.x);
701
+ fd.y = fma(s, b.y, fc.y);
702
+ return fd;
703
+ }
704
+
705
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
706
+
707
+ inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
708
+ {
709
+ Float8_ fd;
710
+ fd.x = fma(a.x, b.x, fc.x);
711
+ fd.y = fma(a.y, b.y, fc.y);
712
+ fd.z = fma(a.z, b.z, fc.z);
713
+ fd.w = fma(a.w, b.w, fc.w);
714
+ return fd;
715
+ }
716
+
717
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
718
+
719
+ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
720
+ {
721
+ __nv_bfloat162 s = bf162bf162(a);
722
+ Float8_ fd;
723
+ fd.x = fma(s, b.x, fc.x);
724
+ fd.y = fma(s, b.y, fc.y);
725
+ fd.z = fma(s, b.z, fc.z);
726
+ fd.w = fma(s, b.w, fc.w);
727
+ return fd;
728
+ }
729
+ #endif // ENABLE_BF16
730
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
731
+
732
+ template<typename Acc, typename A, typename B>
733
+ inline __device__ Acc mul(A a, B b)
734
+ {
735
+ return a * b;
736
+ }
737
+
738
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
739
+
740
+ template<>
741
+ inline __device__ float mul<float, float>(float a, float b)
742
+ {
743
+ return a * b;
744
+ }
745
+
746
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
747
+
748
+ template<>
749
+ inline __device__ float2 mul(float2 a, float2 b)
750
+ {
751
+ float2 c;
752
+ c.x = a.x * b.x;
753
+ c.y = a.y * b.y;
754
+ return c;
755
+ }
756
+
757
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
758
+
759
+ template<>
760
+ inline __device__ float2 mul(float a, float2 b)
761
+ {
762
+ float2 c;
763
+ c.x = a * b.x;
764
+ c.y = a * b.y;
765
+ return c;
766
+ }
767
+
768
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
769
+
770
+ template<>
771
+ inline __device__ float4 mul(float4 a, float4 b)
772
+ {
773
+ float4 c;
774
+ c.x = a.x * b.x;
775
+ c.y = a.y * b.y;
776
+ c.z = a.z * b.z;
777
+ c.w = a.w * b.w;
778
+ return c;
779
+ }
780
+
781
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
782
+
783
+ template<>
784
+ inline __device__ float4 mul(float a, float4 b)
785
+ {
786
+ float4 c;
787
+ c.x = a * b.x;
788
+ c.y = a * b.y;
789
+ c.z = a * b.z;
790
+ c.w = a * b.w;
791
+ return c;
792
+ }
793
+
794
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
795
+
796
+ template<>
797
+ inline __device__ Float8_ mul(float a, Float8_ b)
798
+ {
799
+ Float8_ c;
800
+ c.x = make_float2(a * b.x.x, a * b.x.y);
801
+ c.y = make_float2(a * b.y.x, a * b.y.y);
802
+ c.z = make_float2(a * b.z.x, a * b.z.y);
803
+ c.w = make_float2(a * b.w.x, a * b.w.y);
804
+ return c;
805
+ }
806
+
807
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
808
+
809
+ template<>
810
+ inline __device__ uint16_t mul(uint16_t a, uint16_t b)
811
+ {
812
+ uint16_t c;
813
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
814
+ return c;
815
+ }
816
+
817
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
818
+
819
+ template<>
820
+ inline __device__ uint32_t mul(uint32_t a, uint32_t b)
821
+ {
822
+ uint32_t c;
823
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
824
+ return c;
825
+ }
826
+
827
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
828
+
829
+ template<>
830
+ inline __device__ uint32_t mul(uint16_t a, uint32_t b)
831
+ {
832
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
833
+ }
834
+
835
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
836
+
837
+ template<>
838
+ inline __device__ uint2 mul(uint2 a, uint2 b)
839
+ {
840
+ uint2 c;
841
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
842
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
843
+ return c;
844
+ }
845
+
846
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
847
+
848
+ template<>
849
+ inline __device__ uint2 mul(uint16_t a, uint2 b)
850
+ {
851
+ uint32_t s = h0_h0(a);
852
+ uint2 c;
853
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
854
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
855
+ return c;
856
+ }
857
+
858
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
859
+
860
+ template<>
861
+ inline __device__ uint4 mul(uint4 a, uint4 b)
862
+ {
863
+ uint4 c;
864
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
865
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
866
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
867
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
868
+ return c;
869
+ }
870
+
871
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
872
+
873
+ template<>
874
+ inline __device__ uint4 mul(uint16_t a, uint4 b)
875
+ {
876
+ uint32_t s = h0_h0(a);
877
+ uint4 c;
878
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
879
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
880
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
881
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
882
+ return c;
883
+ }
884
+
885
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
886
+
887
+ template<>
888
+ inline __device__ float mul(uint16_t a, uint16_t b)
889
+ {
890
+ float fa = half_to_float(a);
891
+ float fb = half_to_float(b);
892
+ return fa * fb;
893
+ }
894
+
895
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
896
+
897
+ template<>
898
+ inline __device__ float mul(uint16_t a, float b)
899
+ {
900
+ return half_to_float(a) * b;
901
+ }
902
+
903
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
904
+
905
+ template<>
906
+ inline __device__ float2 mul(uint32_t a, uint32_t b)
907
+ {
908
+ float2 fa = half2_to_float2(a);
909
+ float2 fb = half2_to_float2(b);
910
+ return mul<float2, float2, float2>(fa, fb);
911
+ }
912
+
913
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
914
+
915
+ template<>
916
+ inline __device__ float2 mul(uint16_t a, uint32_t b)
917
+ {
918
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
919
+ }
920
+
921
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
922
+
923
+ template<>
924
+ inline __device__ Float4_ mul(uint2 a, uint2 b)
925
+ {
926
+ Float4_ fc;
927
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
928
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
929
+ return fc;
930
+ }
931
+
932
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
933
+
934
+ template<>
935
+ inline __device__ Float4_ mul(uint16_t a, uint2 b)
936
+ {
937
+ uint32_t s = h0_h0(a);
938
+ Float4_ fc;
939
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
940
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
941
+ return fc;
942
+ }
943
+
944
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
945
+
946
+ template<>
947
+ inline __device__ Float8_ mul(uint4 a, uint4 b)
948
+ {
949
+ Float8_ fc;
950
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
951
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
952
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
953
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
954
+ return fc;
955
+ }
956
+
957
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
958
+
959
+ template<>
960
+ inline __device__ Float8_ mul(uint16_t a, uint4 b)
961
+ {
962
+ uint32_t s = h0_h0(a);
963
+ Float8_ fc;
964
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
965
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
966
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
967
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
968
+ return fc;
969
+ }
970
+
971
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
972
+
973
+ #ifdef ENABLE_BF16
974
+ template<>
975
+ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
976
+ {
977
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
978
+ return __hmul(a, b);
979
+ #else
980
+ return bf16hmul(a, b);
981
+ #endif
982
+ }
983
+
984
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
985
+
986
+ template<>
987
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
988
+ {
989
+ return bf16hmul2(a, b);
990
+ }
991
+
992
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
993
+
994
+ template<>
995
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
996
+ {
997
+ return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
998
+ }
999
+
1000
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1001
+
1002
+ template<>
1003
+ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
1004
+ {
1005
+ bf16_4_t c;
1006
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1007
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1008
+ return c;
1009
+ }
1010
+
1011
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1012
+
1013
+ template<>
1014
+ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
1015
+ {
1016
+ __nv_bfloat162 s = bf162bf162(a);
1017
+ bf16_4_t c;
1018
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1019
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1020
+ return c;
1021
+ }
1022
+
1023
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1024
+
1025
+ template<>
1026
+ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
1027
+ {
1028
+ bf16_8_t c;
1029
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1030
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1031
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
1032
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
1033
+ return c;
1034
+ }
1035
+
1036
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1037
+
1038
+ template<>
1039
+ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
1040
+ {
1041
+ __nv_bfloat162 s = bf162bf162(a);
1042
+ bf16_8_t c;
1043
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1044
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1045
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
1046
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
1047
+ return c;
1048
+ }
1049
+
1050
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1051
+
1052
+ template<>
1053
+ inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
1054
+ {
1055
+ float fa = (float)a;
1056
+ float fb = (float)b;
1057
+ return fa * fb;
1058
+ }
1059
+
1060
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1061
+
1062
+ template<>
1063
+ inline __device__ float mul(__nv_bfloat16 a, float b)
1064
+ {
1065
+ return __bfloat162float(a) * b;
1066
+ }
1067
+
1068
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1069
+
1070
+ template<>
1071
+ inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
1072
+ {
1073
+ float2 fa = bf1622float2(a);
1074
+ float2 fb = bf1622float2(b);
1075
+ return mul<float2, float2, float2>(fa, fb);
1076
+ }
1077
+
1078
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1079
+
1080
+ template<>
1081
+ inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
1082
+ {
1083
+ return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
1084
+ }
1085
+
1086
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1087
+
1088
+ template<>
1089
+ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
1090
+ {
1091
+ Float4_ fc;
1092
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1093
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1094
+ return fc;
1095
+ }
1096
+
1097
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1098
+
1099
+ template<>
1100
+ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
1101
+ {
1102
+ __nv_bfloat162 s = bf162bf162(a);
1103
+ Float4_ fc;
1104
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1105
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1106
+ return fc;
1107
+ }
1108
+
1109
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1110
+
1111
+ template<>
1112
+ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
1113
+ {
1114
+ Float8_ fc;
1115
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
1116
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
1117
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
1118
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
1119
+ return fc;
1120
+ }
1121
+
1122
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1123
+
1124
+ template<>
1125
+ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
1126
+ {
1127
+ __nv_bfloat162 s = bf162bf162(a);
1128
+ Float8_ fc;
1129
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
1130
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
1131
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
1132
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
1133
+ return fc;
1134
+ }
1135
+ #endif // ENABLE_BF16
1136
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1137
+
1138
+ inline __device__ float sum(float v)
1139
+ {
1140
+ return v;
1141
+ }
1142
+
1143
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1144
+
1145
+ inline __device__ float sum(float2 v)
1146
+ {
1147
+ return v.x + v.y;
1148
+ }
1149
+
1150
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1151
+
1152
+ inline __device__ float sum(float4 v)
1153
+ {
1154
+ return v.x + v.y + v.z + v.w;
1155
+ }
1156
+
1157
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1158
+
1159
+ #ifdef ENABLE_BF16
1160
+ inline __device__ float sum(__nv_bfloat162 v)
1161
+ {
1162
+ float2 vf = bf1622float2(v);
1163
+ return vf.x + vf.y;
1164
+ }
1165
+
1166
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1167
+
1168
+ inline __device__ float sum(bf16_4_t v)
1169
+ {
1170
+ return sum(v.x) + sum(v.y);
1171
+ }
1172
+
1173
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1174
+
1175
+ inline __device__ float sum(bf16_8_t v)
1176
+ {
1177
+ return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
1178
+ }
1179
+ #endif // ENABLE_BF16
1180
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1181
+
1182
+ inline __device__ float sum(uint16_t v)
1183
+ {
1184
+ return half_to_float(v);
1185
+ }
1186
+
1187
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1188
+
1189
+ inline __device__ float sum(uint32_t v)
1190
+ {
1191
+ float2 tmp = half2_to_float2(v);
1192
+ return tmp.x + tmp.y;
1193
+ }
1194
+
1195
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1196
+
1197
+ inline __device__ float sum(uint2 v)
1198
+ {
1199
+ uint32_t c = add(v.x, v.y);
1200
+ return sum(c);
1201
+ }
1202
+
1203
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1204
+
1205
+ inline __device__ float sum(uint4 v)
1206
+ {
1207
+ #if 1
1208
+ uint32_t c = add(v.x, v.y);
1209
+ c = add(c, v.z);
1210
+ c = add(c, v.w);
1211
+ #else
1212
+ uint32_t c = add(v.x, v.y);
1213
+ uint32_t d = add(v.z, v.w);
1214
+ c = add(c, d);
1215
+ #endif
1216
+ return sum(c);
1217
+ }
1218
+
1219
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1220
+
1221
+ inline __device__ float sum(Float4_ v)
1222
+ {
1223
+ return v.x.x + v.x.y + v.y.x + v.y.y;
1224
+ }
1225
+
1226
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1227
+
1228
+ inline __device__ float sum(Float8_ v)
1229
+ {
1230
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
1231
+ }
1232
+
1233
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1234
+
1235
+ template<typename T>
1236
+ inline __device__ float dot(T a, T b)
1237
+ {
1238
+ return sum(mul<T, T, T>(a, b));
1239
+ }
1240
+
1241
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1242
+
1243
+ template<typename A, typename T>
1244
+ inline __device__ float dot(T a, T b)
1245
+ {
1246
+ return sum(mul<A, T, T>(a, b));
1247
+ }
1248
+
1249
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1250
+
1251
+ inline __device__ void zero(uint16_t& dst)
1252
+ {
1253
+ dst = uint16_t(0);
1254
+ }
1255
+
1256
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1257
+
1258
+ template<typename T>
1259
+ inline __device__ void zero(T& dst)
1260
+ {
1261
+ constexpr int WORDS = sizeof(T) / 4;
1262
+ union {
1263
+ T raw;
1264
+ uint32_t words[WORDS];
1265
+ } tmp;
1266
+ #pragma unroll
1267
+ for (int ii = 0; ii < WORDS; ++ii) {
1268
+ tmp.words[ii] = 0u;
1269
+ }
1270
+ dst = tmp.raw;
1271
+ }
1272
+
1273
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
1274
+
1275
+ inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
1276
+ {
1277
+ const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
1278
+ return {cos(inv_freq), sin(inv_freq)};
1279
+ }
1280
+
1281
+ inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
1282
+ {
1283
+ float2 rot_v;
1284
+ rot_v.x = coef.x * v.x - coef.y * v.y;
1285
+ rot_v.y = coef.x * v.y + coef.y * v.x;
1286
+ return rot_v;
1287
+ }
1288
+
1289
+ inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
1290
+ {
1291
+ float2 fv = half2_to_float2(v);
1292
+ float2 rot_fv = rotary_embedding_transform(fv, coef);
1293
+ return float2_to_half2(rot_fv);
1294
+ }
1295
+
1296
+ #ifdef ENABLE_BF16
1297
+ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
1298
+ {
1299
+ float2 fv = bf1622float2(v);
1300
+ float2 rot_fv = rotary_embedding_transform(fv, coef);
1301
+ return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
1302
+ }
1303
+ #endif
1304
+
1305
+ inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
1306
+ {
1307
+ return;
1308
+ }
1309
+
1310
+ inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
1311
+ {
1312
+ return;
1313
+ }
1314
+
1315
+ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1316
+ {
1317
+ if (2 * tid >= rot_embed_dim) {
1318
+ return;
1319
+ }
1320
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1321
+ q = rotary_embedding_transform(q, coef);
1322
+ }
1323
+
1324
+ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1325
+ {
1326
+ if (2 * tid >= rot_embed_dim) {
1327
+ return;
1328
+ }
1329
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1330
+ q = rotary_embedding_transform(q, coef);
1331
+ k = rotary_embedding_transform(k, coef);
1332
+ }
1333
+
1334
+ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1335
+ {
1336
+ if (4 * tid >= rot_embed_dim) {
1337
+ return;
1338
+ }
1339
+
1340
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
1341
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1342
+ q_.x = rotary_embedding_transform(q_.x, coef0);
1343
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1344
+ q_.y = rotary_embedding_transform(q_.y, coef1);
1345
+ }
1346
+
1347
+ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1348
+ {
1349
+ if (4 * tid >= rot_embed_dim) {
1350
+ return;
1351
+ }
1352
+
1353
+ Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
1354
+ Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
1355
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1356
+ q_.x = rotary_embedding_transform(q_.x, coef0);
1357
+ k_.x = rotary_embedding_transform(k_.x, coef0);
1358
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1359
+ q_.y = rotary_embedding_transform(q_.y, coef1);
1360
+ k_.y = rotary_embedding_transform(k_.y, coef1);
1361
+ }
1362
+
1363
+ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1364
+ {
1365
+ if (2 * tid >= rot_embed_dim) {
1366
+ return;
1367
+ }
1368
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1369
+ q = rotary_embedding_transform(q, coef);
1370
+ }
1371
+
1372
+ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1373
+ {
1374
+ if (2 * tid >= rot_embed_dim) {
1375
+ return;
1376
+ }
1377
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1378
+ q = rotary_embedding_transform(q, coef);
1379
+ k = rotary_embedding_transform(k, coef);
1380
+ }
1381
+
1382
+ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1383
+ {
1384
+ if (4 * tid >= rot_embed_dim) {
1385
+ return;
1386
+ }
1387
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1388
+ q.x = rotary_embedding_transform(q.x, coef0);
1389
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1390
+ q.y = rotary_embedding_transform(q.y, coef1);
1391
+ }
1392
+
1393
+ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1394
+ {
1395
+ if (4 * tid >= rot_embed_dim) {
1396
+ return;
1397
+ }
1398
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1399
+ q.x = rotary_embedding_transform(q.x, coef0);
1400
+ k.x = rotary_embedding_transform(k.x, coef0);
1401
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1402
+ q.y = rotary_embedding_transform(q.y, coef1);
1403
+ k.y = rotary_embedding_transform(k.y, coef1);
1404
+ }
1405
+
1406
+ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1407
+ {
1408
+ if (8 * tid >= rot_embed_dim) {
1409
+ return;
1410
+ }
1411
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1412
+ q.x = rotary_embedding_transform(q.x, coef0);
1413
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1414
+ q.y = rotary_embedding_transform(q.y, coef1);
1415
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1416
+ q.z = rotary_embedding_transform(q.z, coef2);
1417
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1418
+ q.w = rotary_embedding_transform(q.w, coef3);
1419
+ }
1420
+
1421
+ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1422
+ {
1423
+ if (8 * tid >= rot_embed_dim) {
1424
+ return;
1425
+ }
1426
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1427
+ q.x = rotary_embedding_transform(q.x, coef0);
1428
+ k.x = rotary_embedding_transform(k.x, coef0);
1429
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1430
+ q.y = rotary_embedding_transform(q.y, coef1);
1431
+ k.y = rotary_embedding_transform(k.y, coef1);
1432
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1433
+ q.z = rotary_embedding_transform(q.z, coef2);
1434
+ k.z = rotary_embedding_transform(k.z, coef2);
1435
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1436
+ q.w = rotary_embedding_transform(q.w, coef3);
1437
+ k.w = rotary_embedding_transform(k.w, coef3);
1438
+ }
1439
+
1440
+ #ifdef ENABLE_BF16
1441
+ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1442
+ {
1443
+ if (2 * tid >= rot_embed_dim) {
1444
+ return;
1445
+ }
1446
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1447
+ q = rotary_embedding_transform(q, coef);
1448
+ }
1449
+
1450
+ inline __device__ void
1451
+ apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1452
+ {
1453
+ if (2 * tid >= rot_embed_dim) {
1454
+ return;
1455
+ }
1456
+ const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
1457
+ q = rotary_embedding_transform(q, coef);
1458
+ k = rotary_embedding_transform(k, coef);
1459
+ }
1460
+
1461
+ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1462
+ {
1463
+ if (4 * tid >= rot_embed_dim) {
1464
+ return;
1465
+ }
1466
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1467
+ q.x = rotary_embedding_transform(q.x, coef0);
1468
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1469
+ q.y = rotary_embedding_transform(q.y, coef1);
1470
+ }
1471
+
1472
+ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1473
+ {
1474
+ if (4 * tid >= rot_embed_dim) {
1475
+ return;
1476
+ }
1477
+ const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
1478
+ q.x = rotary_embedding_transform(q.x, coef0);
1479
+ k.x = rotary_embedding_transform(k.x, coef0);
1480
+ const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
1481
+ q.y = rotary_embedding_transform(q.y, coef1);
1482
+ k.y = rotary_embedding_transform(k.y, coef1);
1483
+ }
1484
+
1485
+ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1486
+ {
1487
+ if (8 * tid >= rot_embed_dim) {
1488
+ return;
1489
+ }
1490
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1491
+ q.x = rotary_embedding_transform(q.x, coef0);
1492
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1493
+ q.y = rotary_embedding_transform(q.y, coef1);
1494
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1495
+ q.z = rotary_embedding_transform(q.z, coef2);
1496
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1497
+ q.w = rotary_embedding_transform(q.w, coef3);
1498
+ }
1499
+
1500
+ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
1501
+ {
1502
+ if (8 * tid >= rot_embed_dim) {
1503
+ return;
1504
+ }
1505
+ const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
1506
+ q.x = rotary_embedding_transform(q.x, coef0);
1507
+ k.x = rotary_embedding_transform(k.x, coef0);
1508
+ const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
1509
+ q.y = rotary_embedding_transform(q.y, coef1);
1510
+ k.y = rotary_embedding_transform(k.y, coef1);
1511
+ const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
1512
+ q.z = rotary_embedding_transform(q.z, coef2);
1513
+ k.z = rotary_embedding_transform(k.z, coef2);
1514
+ const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
1515
+ q.w = rotary_embedding_transform(q.w, coef3);
1516
+ k.w = rotary_embedding_transform(k.w, coef3);
1517
+ }
1518
+ #endif // ENABLE_BF16
1519
+
1520
+ template<typename Vec_T, typename T>
1521
+ __device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
1522
+
1523
+ template<>
1524
+ __device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
1525
+ {
1526
+ return;
1527
+ }
1528
+
1529
+ template<>
1530
+ __device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1531
+ {
1532
+ union {
1533
+ uint32_t u32;
1534
+ uint16_t u16[2];
1535
+ } tmp;
1536
+ tmp.u16[0] = smem[transpose_idx];
1537
+ tmp.u16[1] = smem[smem_pitch + transpose_idx];
1538
+
1539
+ vec = tmp.u32;
1540
+ }
1541
+
1542
+ template<>
1543
+ __device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1544
+ {
1545
+ union {
1546
+ uint32_t u32;
1547
+ uint16_t u16[2];
1548
+ } tmp_1, tmp_2;
1549
+ tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
1550
+ tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
1551
+
1552
+ union {
1553
+ uint2 u32x2;
1554
+ uint16_t u16[4];
1555
+ } tmp_3;
1556
+ tmp_3.u16[0] = tmp_1.u16[0];
1557
+ tmp_3.u16[1] = tmp_2.u16[0];
1558
+ tmp_3.u16[2] = tmp_1.u16[1];
1559
+ tmp_3.u16[3] = tmp_2.u16[1];
1560
+
1561
+ vec = tmp_3.u32x2;
1562
+ }
1563
+
1564
+ template<>
1565
+ __device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1566
+ {
1567
+ union {
1568
+ uint64_t u64;
1569
+ uint16_t u16[4];
1570
+ } tmp_1, tmp_2;
1571
+ tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
1572
+ tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
1573
+
1574
+ union {
1575
+ uint4 u32x4;
1576
+ uint16_t u16[8];
1577
+ } tmp_3;
1578
+ tmp_3.u16[0] = tmp_1.u16[0];
1579
+ tmp_3.u16[1] = tmp_2.u16[0];
1580
+ tmp_3.u16[2] = tmp_1.u16[1];
1581
+ tmp_3.u16[3] = tmp_2.u16[1];
1582
+ tmp_3.u16[4] = tmp_1.u16[2];
1583
+ tmp_3.u16[5] = tmp_2.u16[2];
1584
+ tmp_3.u16[6] = tmp_1.u16[3];
1585
+ tmp_3.u16[7] = tmp_2.u16[3];
1586
+
1587
+ vec = tmp_3.u32x4;
1588
+ }
1589
+
1590
+ #ifdef ENABLE_BF16
1591
+ template<>
1592
+ __device__ __inline__ void
1593
+ vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1594
+ {
1595
+ union {
1596
+ uint32_t u32;
1597
+ __nv_bfloat16 bf16[2];
1598
+ } tmp_1, tmp_2;
1599
+ tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
1600
+ tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
1601
+
1602
+ vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
1603
+ vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
1604
+ }
1605
+
1606
+ template<>
1607
+ __device__ __inline__ void
1608
+ vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1609
+ {
1610
+ union {
1611
+ uint64_t u64;
1612
+ __nv_bfloat16 bf16[4];
1613
+ } tmp_1, tmp_2;
1614
+ tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
1615
+ tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
1616
+
1617
+ vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
1618
+ vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
1619
+ vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
1620
+ vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
1621
+ }
1622
+ #endif // ENABLE_BF16
1623
+
1624
+ template<>
1625
+ __device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
1626
+ {
1627
+ vec.x = smem[transpose_idx];
1628
+ vec.z = smem[transpose_idx + 1];
1629
+ vec.y = smem[smem_pitch + transpose_idx];
1630
+ vec.w = smem[smem_pitch + transpose_idx + 1];
1631
+ }
1632
+
1633
+ template<>
1634
+ __device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
1635
+ {
1636
+ union {
1637
+ uint32_t u32;
1638
+ half u16[2];
1639
+ } tmp;
1640
+ tmp.u16[0] = smem[transpose_idx];
1641
+ tmp.u16[1] = smem[smem_pitch + transpose_idx];
1642
+
1643
+ vec = tmp.u32;
1644
+ }
1645
+
1646
+ #ifdef ENABLE_BF16
1647
+ template<>
1648
+ __device__ __inline__ void
1649
+ vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1650
+ {
1651
+ vec.x = smem[transpose_idx];
1652
+ vec.y = smem[smem_pitch + transpose_idx];
1653
+ }
1654
+ #endif
1655
+
1656
+ template<>
1657
+ __device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
1658
+ {
1659
+ vec.x = smem[transpose_idx];
1660
+ vec.y = smem[smem_pitch + transpose_idx];
1661
+ }
1662
+
1663
+ template<typename Vec_T, typename T>
1664
+ __device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
1665
+
1666
+ template<>
1667
+ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
1668
+ {
1669
+ return;
1670
+ }
1671
+
1672
+ template<>
1673
+ __device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1674
+ {
1675
+ union {
1676
+ uint64_t u64;
1677
+ uint16_t u16[4];
1678
+ } tmp_1, tmp_2;
1679
+
1680
+ union {
1681
+ uint4 u32x4;
1682
+ uint16_t u16[8];
1683
+ } tmp_3;
1684
+ tmp_3.u32x4 = vec;
1685
+ tmp_1.u16[0] = tmp_3.u16[0];
1686
+ tmp_2.u16[0] = tmp_3.u16[1];
1687
+ tmp_1.u16[1] = tmp_3.u16[2];
1688
+ tmp_2.u16[1] = tmp_3.u16[3];
1689
+ tmp_1.u16[2] = tmp_3.u16[4];
1690
+ tmp_2.u16[2] = tmp_3.u16[5];
1691
+ tmp_1.u16[3] = tmp_3.u16[6];
1692
+ tmp_2.u16[3] = tmp_3.u16[7];
1693
+
1694
+ *reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
1695
+ *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
1696
+ }
1697
+
1698
+ template<>
1699
+ __device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1700
+ {
1701
+ union {
1702
+ uint32_t u32;
1703
+ uint16_t u16[2];
1704
+ } tmp_1, tmp_2;
1705
+
1706
+ union {
1707
+ uint2 u32x2;
1708
+ uint16_t u16[4];
1709
+ } tmp_3;
1710
+ tmp_3.u32x2 = vec;
1711
+ tmp_1.u16[0] = tmp_3.u16[0];
1712
+ tmp_2.u16[0] = tmp_3.u16[1];
1713
+ tmp_1.u16[1] = tmp_3.u16[2];
1714
+ tmp_2.u16[1] = tmp_3.u16[3];
1715
+
1716
+ *reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
1717
+ *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
1718
+ }
1719
+
1720
+ template<>
1721
+ __device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
1722
+ {
1723
+ union {
1724
+ uint32_t u32;
1725
+ uint16_t u16[2];
1726
+ } tmp;
1727
+ tmp.u32 = vec;
1728
+
1729
+ smem[transpose_idx] = tmp.u16[0];
1730
+ smem[smem_pitch + transpose_idx] = tmp.u16[1];
1731
+ }
1732
+
1733
+ template<>
1734
+ __device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
1735
+ {
1736
+ smem[transpose_idx] = vec.x;
1737
+ smem[transpose_idx + 1] = vec.z;
1738
+ smem[smem_pitch + transpose_idx] = vec.y;
1739
+ smem[smem_pitch + transpose_idx + 1] = vec.w;
1740
+ }
1741
+
1742
+ template<>
1743
+ __device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
1744
+ {
1745
+ union {
1746
+ uint32_t u32;
1747
+ half u16[2];
1748
+ } tmp;
1749
+
1750
+ tmp.u32 = vec;
1751
+ smem[transpose_idx] = tmp.u16[0];
1752
+ smem[smem_pitch + transpose_idx] = tmp.u16[1];
1753
+ }
1754
+
1755
+ #ifdef ENABLE_BF16
1756
+ template<>
1757
+ __device__ __inline__ void
1758
+ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1759
+ {
1760
+ smem[transpose_idx] = vec.x;
1761
+ smem[smem_pitch + transpose_idx] = vec.y;
1762
+ }
1763
+
1764
+ template<>
1765
+ __device__ __inline__ void
1766
+ write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1767
+ {
1768
+ write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
1769
+ }
1770
+
1771
+ template<>
1772
+ __device__ __inline__ void
1773
+ write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
1774
+ {
1775
+ write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
1776
+ }
1777
+ #endif
1778
+
1779
+ template<>
1780
+ __device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
1781
+ {
1782
+ smem[transpose_idx] = vec.x;
1783
+ smem[smem_pitch + transpose_idx] = vec.y;
1784
+ }
1785
+
1786
+ } // namespace mmha
AutoAWQ_kernels/awq_ext/attention/ft_attention.cpp ADDED
@@ -0,0 +1,182 @@
1
+ // Adapted from NVIDIA/FasterTransformer and FlashAttention
2
+
3
+ #include <torch/extension.h>
4
+ #include "ATen/cuda/CUDAContext.h"
5
+ #include <c10/cuda/CUDAGuard.h>
6
+
7
+ #include "ft_attention.h"
8
+ #include "decoder_masked_multihead_attention.h"
9
+
10
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
11
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
12
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+
14
+ #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
15
+ if (TYPE == at::ScalarType::Half) { \
16
+ using scalar_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (TYPE == at::ScalarType::BFloat16) { \
19
+ using scalar_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (TYPE == at::ScalarType::Float) { \
22
+ using scalar_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
26
+ }
27
+
28
+ template<typename T>
29
+ void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
30
+ const cudaStream_t& stream);
31
+
32
+ template<typename T>
33
+ void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
34
+ const cudaStream_t& stream);
35
+
36
+ template<typename T>
37
+ struct SATypeConverter {
38
+ using Type = T;
39
+ };
40
+
41
+ template<>
42
+ struct SATypeConverter<at::Half> {
43
+ using Type = uint16_t;
44
+ };
45
+
46
+ template<>
47
+ struct SATypeConverter<at::BFloat16> {
48
+ using Type = __nv_bfloat16;
49
+ };
50
+
51
+ template <typename T>
52
+ void set_params(Masked_multihead_attention_params<T> &params,
53
+ const size_t batch_size,
54
+ const size_t nheads,
55
+ const size_t nheads_kv,
56
+ const size_t memory_max_seqlen,
57
+ const size_t headdim,
58
+ const int timestep,
59
+ const int rotary_embedding_dim,
60
+ const float rotary_base,
61
+ const bool neox_rotary_style,
62
+ const int qkv_batch_stride,
63
+ T *q_ptr,
64
+ T *k_ptr,
65
+ T *v_ptr,
66
+ T *k_cache_ptr,
67
+ T *v_cache_ptr,
68
+ int *length_per_sample,
69
+ float *alibi_slopes_ptr,
70
+ T *out_ptr) {
71
+ // Reset the parameters
72
+ memset(&params, 0, sizeof(params));
73
+ params.q = q_ptr;
74
+ params.k = k_ptr;
75
+ params.v = v_ptr;
76
+ params.q_bias = nullptr;
77
+ params.k_bias = nullptr;
78
+ params.v_bias = nullptr;
79
+ params.k_cache = k_cache_ptr;
80
+ params.v_cache = v_cache_ptr;
81
+ params.linear_bias_slopes = alibi_slopes_ptr;
82
+ params.out = out_ptr;
83
+ params.cache_indir = nullptr;
84
+ params.stride = qkv_batch_stride;
85
+ params.batch_size = batch_size;
86
+ params.beam_width = 1;
87
+ params.memory_max_len = memory_max_seqlen;
88
+ params.num_heads = nheads;
89
+ params.num_kv_heads = nheads_kv;
90
+ params.hidden_size_per_head = headdim;
91
+ params.rotary_embedding_dim = rotary_embedding_dim;
92
+ params.rotary_base = rotary_base;
93
+ params.neox_rotary_style = neox_rotary_style;
94
+ params.timestep = timestep;
95
+ params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
96
+ params.total_padding_tokens = nullptr;
97
+ params.masked_tokens = nullptr;
98
+ params.prefix_prompt_lengths = nullptr;
99
+ params.max_prefix_prompt_length = 0;
100
+ params.relative_attention_bias = nullptr;
101
+ params.relative_attention_bias_stride = 0;
102
+ params.cross_attention_out = nullptr;
103
+ params.max_decoder_seq_len = 0;
104
+ params.is_return_cross_attentions = false;
105
+ params.finished = nullptr;
106
+ params.memory_length_per_sample = nullptr;
107
+ params.length_per_sample = length_per_sample;
108
+ }
109
+
110
+ torch::Tensor single_query_attention(const torch::Tensor q,
111
+ const torch::Tensor k,
112
+ const torch::Tensor v,
113
+ torch::Tensor k_cache,
114
+ torch::Tensor v_cache,
115
+ c10::optional<const torch::Tensor> length_per_sample_,
116
+ c10::optional<const torch::Tensor> alibi_slopes_,
117
+ const int timestep,
118
+ const int rotary_embedding_dim,
119
+ const float rotary_base,
120
+ // neox_rotary_style = not interleaved
121
+ const bool neox_rotary_style) {
122
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
123
+ int batch_size = v_cache.size(0);
124
+ int nheads = q.size(1);
125
+ int nheads_kv = v_cache.size(1);
126
+ int memory_max_seqlen = v_cache.size(2);
127
+ int headdim = v_cache.size(3);
128
+ CHECK_SHAPE(q, batch_size, nheads, headdim);
129
+ CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
130
+ CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
131
+ CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
132
+ // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
133
+ int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
134
+ CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
135
+ TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
136
+ TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
137
+ TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
138
+ // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
139
+ CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
140
+
141
+ if (length_per_sample_.has_value()) {
142
+ auto length_per_sample = length_per_sample_.value();
143
+ CHECK_DEVICE(length_per_sample);
144
+ CHECK_SHAPE(length_per_sample, batch_size);
145
+ CHECK_CONTIGUOUS(length_per_sample);
146
+ TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
147
+ }
148
+
149
+ if (alibi_slopes_.has_value()) {
150
+ auto alibi_slopes = alibi_slopes_.value();
151
+ CHECK_DEVICE(alibi_slopes);
152
+ CHECK_SHAPE(alibi_slopes, nheads);
153
+ CHECK_CONTIGUOUS(alibi_slopes);
154
+ TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
155
+ }
156
+
157
+ // Otherwise the kernel will be launched from cuda:0 device
158
+ // Cast to char to avoid compiler warning about narrowing
159
+ at::cuda::CUDAGuard device_guard{(char)q.get_device()};
160
+
161
+ torch::Tensor out = torch::empty_like(q);
162
+
163
+ DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
164
+ using DataType = typename SATypeConverter<scalar_t>::Type;
165
+ Masked_multihead_attention_params<DataType> params;
166
+ set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim,
167
+ timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
168
+ reinterpret_cast<DataType*>(q.data_ptr()),
169
+ reinterpret_cast<DataType*>(k.data_ptr()),
170
+ reinterpret_cast<DataType*>(v.data_ptr()),
171
+ reinterpret_cast<DataType*>(k_cache.data_ptr()),
172
+ reinterpret_cast<DataType*>(v_cache.data_ptr()),
173
+ length_per_sample_.has_value()
174
+ ? length_per_sample_.value().data_ptr<int>() : nullptr,
175
+ alibi_slopes_.has_value()
176
+ ? alibi_slopes_.value().data_ptr<float>(): nullptr,
177
+ reinterpret_cast<DataType*>(out.data_ptr()));
178
+ auto stream = at::cuda::getCurrentCUDAStream();
179
+ masked_multihead_attention(params, stream);
180
+ });
181
+ return out;
182
+ }
AutoAWQ_kernels/awq_ext/attention/ft_attention.h ADDED
@@ -0,0 +1,15 @@
+ #pragma once
+ #include <torch/extension.h>
+
+
+ torch::Tensor single_query_attention(const torch::Tensor q,
+                                      const torch::Tensor k,
+                                      const torch::Tensor v,
+                                      torch::Tensor k_cache,
+                                      torch::Tensor v_cache,
+                                      c10::optional<const torch::Tensor> length_per_sample_,
+                                      c10::optional<const torch::Tensor> alibi_slopes_,
+                                      const int timestep,
+                                      const int rotary_embedding_dim = 0,
+                                      const float rotary_base = 10000.0f,
+                                      const bool neox_rotary_style=true);
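For orientation only (this caller is not part of the commit): the tensor shapes expected by single_query_attention follow the CHECK_SHAPE calls in ft_attention.cpp above. A minimal, hypothetical host-side sketch, assuming fp16 tensors (k_cache pack factor 8) and made-up dimensions:

// Hypothetical caller sketch; shapes mirror the checks in ft_attention.cpp.
#include <torch/torch.h>
#include "ft_attention.h"

int main()
{
    const int B = 1, H = 32, H_kv = 32, D = 128, L = 2048;
    const int packsize = 8;                                   // 8 for fp16 k_cache, 4 for fp32
    auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);

    auto q = torch::randn({B, H, D}, opts);                   // query for the current step
    auto k = torch::randn({B, H_kv, D}, opts);                // key for the current step
    auto v = torch::randn({B, H_kv, D}, opts);                // value for the current step
    auto k_cache = torch::zeros({B, H_kv, D / packsize, L, packsize}, opts);
    auto v_cache = torch::zeros({B, H_kv, L, D}, opts);

    // timestep counts tokens already written to the cache; rotary arguments use the header defaults.
    auto out = single_query_attention(q, k, v, k_cache, v_cache,
                                      c10::nullopt, c10::nullopt, /*timestep=*/0);
    return 0;
}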
AutoAWQ_kernels/awq_ext/exllama/cu_compat.cuh ADDED
@@ -0,0 +1,58 @@
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+ #ifndef _cuda_compat_cuh
+ #define _cuda_compat_cuh
+
+ // atomicAdd for half types, to support CC < 7.x
+
+ __device__ __forceinline__ void atomicAdd_half(half* address, half val)
+ {
+     unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+     unsigned int old = *address_as_ui;
+     unsigned int assumed;
+
+     do
+     {
+         assumed = old;
+         __half_raw hsum;
+         hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+         half tmpres = __hadd(hsum, val);
+         hsum = __half_raw(tmpres);
+         old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+         old = atomicCAS(address_as_ui, assumed, old);
+     }
+     while (assumed != old);
+ }
+
+ // atomicAdd for half2 types
+
+ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
+ {
+     unsigned int* address_as_ui = (unsigned int*)address;
+     unsigned int old = *address_as_ui;
+     unsigned int assumed;
+     do
+     {
+         assumed = old;
+         half2 old_val = *((half2*)&old);
+         half2 new_val = __hadd2(old_val, val);
+         old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
+     }
+     while (assumed != old);
+ }
+
+ //
+
+ #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+ #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
+
+ __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+
+ #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
+ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+ #endif
+
+ #endif
+ #endif
+
+ #endif
AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cu ADDED
@@ -0,0 +1,75 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #define _cuda_buffers_cu
4
+ #include "cuda_buffers.cuh"
5
+
6
+ CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
7
+ // __constant__ half2 q4_table[16][256];
8
+ // half2 q4_table_host[16][256];
9
+ // bool q4_table_init = false;
10
+
11
+ CudaBuffers::CudaBuffers
12
+ (
13
+ int _device,
14
+ int _temp_state_size,
15
+ half* _temp_state,
16
+ half* _temp_dq
17
+ ) :
18
+ device(_device),
19
+ temp_state_size(_temp_state_size),
20
+ temp_state(_temp_state),
21
+ temp_dq(_temp_dq)
22
+ {
23
+ cudaSetDevice(_device);
24
+
25
+ cudaStreamCreate(&alt_stream_1);
26
+ cudaStreamCreate(&alt_stream_2);
27
+ cudaStreamCreate(&alt_stream_3);
28
+ cudaEventCreate(&alt_stream_1_done);
29
+ cudaEventCreate(&alt_stream_2_done);
30
+ cudaEventCreate(&alt_stream_3_done);
31
+ }
32
+
33
+ CudaBuffers::~CudaBuffers()
34
+ {
35
+ cudaStreamDestroy(alt_stream_1);
36
+ cudaStreamDestroy(alt_stream_2);
37
+ cudaStreamDestroy(alt_stream_3);
38
+ cudaEventDestroy(alt_stream_1_done);
39
+ cudaEventDestroy(alt_stream_2_done);
40
+ cudaEventDestroy(alt_stream_3_done);
41
+ }
42
+
43
+ CudaBuffers* get_buffers(const int device_index)
44
+ {
45
+ return g_buffers[device_index];
46
+ }
47
+
48
+ void prepare_buffers_cuda
49
+ (
50
+ int _device,
51
+ int _temp_state_size,
52
+ half* _temp_state,
53
+ half* _temp_dq
54
+ )
55
+ {
56
+ CudaBuffers* buffers = new CudaBuffers
57
+ (
58
+ _device,
59
+ _temp_state_size,
60
+ _temp_state,
61
+ _temp_dq
62
+ );
63
+
64
+ g_buffers[_device] = buffers;
65
+ }
66
+
67
+ void cleanup_buffers_cuda()
68
+ {
69
+ for (int i = 0; i < CUDA_MAX_DEVICES; i++)
70
+ {
71
+ if (!g_buffers[i]) continue;
72
+ delete g_buffers[i];
73
+ g_buffers[i] = NULL;
74
+ }
75
+ }
AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cuh ADDED
@@ -0,0 +1,55 @@
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+ #ifndef _cuda_buffers_cuh
+ #define _cuda_buffers_cuh
+
+ #include <cuda_runtime.h>
+ #include <cuda_fp16.h>
+ #include <cstdint>
+ #include <cstdio>
+
+ const int CUDA_MAX_DEVICES = 16;
+
+ // #ifndef _cuda_buffers_cu
+ // extern __constant__ half2 q4_table[16][256];
+ // #endif
+
+ class CudaBuffers
+ {
+ public:
+     int device;
+
+     half* temp_state;    // [max_hidden_rows * intermediate_size]
+     int temp_state_size;
+     half* temp_dq;       // size of largest quant tensor * 8
+
+     cudaStream_t alt_stream_1;
+     cudaStream_t alt_stream_2;
+     cudaStream_t alt_stream_3;
+     cudaEvent_t alt_stream_1_done;
+     cudaEvent_t alt_stream_2_done;
+     cudaEvent_t alt_stream_3_done;
+
+     CudaBuffers
+     (
+         int _device,
+         int _temp_state_size,
+         half* _temp_state,
+         half* _temp_dq
+     );
+     ~CudaBuffers();
+ };
+
+ CudaBuffers* get_buffers(const int device_index);
+
+ void prepare_buffers_cuda
+ (
+     int _device,
+     int _temp_state_size,
+     half* _temp_state,
+     half* _temp_dq
+ );
+
+ void cleanup_buffers_cuda();
+
+ #endif
AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cu ADDED
@@ -0,0 +1,63 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #include "column_remap.cuh"
4
+ #include "../util.cuh"
5
+
6
+ const int SHUF_BLOCKSIZE_X = 256;
7
+ const int SHUF_BLOCKSIZE_Y = 16;
8
+
9
+ __global__ void column_remap_kernel
10
+ (
11
+ const half* __restrict__ x,
12
+ half* __restrict__ x_new,
13
+ const int x_width,
14
+ const int x_height,
15
+ const uint32_t* x_map
16
+ )
17
+ {
18
+ int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
19
+ int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
20
+ if (x_column >= x_width) return;
21
+ //if (x_row >= x_height) return;
22
+
23
+ int x_stride = x_width;
24
+ int x_idx = x_row * x_stride + x_column;
25
+
26
+ int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
27
+ int x_idx_end = x_row_end * x_stride + x_column;
28
+
29
+ int s_column = x_map[x_column];
30
+ int s_idx = x_row * x_stride + s_column;
31
+
32
+ while (x_idx < x_idx_end)
33
+ {
34
+ x_new[x_idx] = x[s_idx];
35
+ x_idx += x_stride;
36
+ s_idx += x_stride;
37
+ }
38
+ }
39
+
40
+ // Remap columns in x to correspond to sequential group index before matmul
41
+ //
42
+ // perform x -> seq_x such that seq_x @ seq_w == x @ w
43
+
44
+ void column_remap_cuda
45
+ (
46
+ const half* x,
47
+ half* x_new,
48
+ const int x_height,
49
+ const int x_width,
50
+ const uint32_t* x_map
51
+ )
52
+ {
53
+ dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
54
+
55
+ dim3 blocks
56
+ (
57
+ (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
58
+ (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
59
+ 1
60
+ );
61
+
62
+ column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
63
+ }
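The comment above states the intent: remap x into seq_x so that seq_x @ seq_w == x @ w. Concretely, column c of the output gathers column x_map[c] of the input. As a plain CPU reference (illustrative only, not part of the commit; float stands in for half on the host):

// CPU reference for the gather performed by column_remap_kernel.
#include <cstdint>
#include <vector>

void column_remap_reference(const std::vector<float>& x, std::vector<float>& x_new,
                            int height, int width, const std::vector<uint32_t>& x_map)
{
    for (int r = 0; r < height; r++)
        for (int c = 0; c < width; c++)
            x_new[r * width + c] = x[r * width + x_map[c]];  // column c <- column x_map[c]
}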
AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cuh ADDED
@@ -0,0 +1,19 @@
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+ #ifndef _column_remap_cuh
+ #define _column_remap_cuh
+
+ #include <cuda_runtime.h>
+ #include <cuda_fp16.h>
+ #include <cstdint>
+
+ void column_remap_cuda
+ (
+     const half* x,
+     half* x_new,
+     const int x_height,
+     const int x_width,
+     const uint32_t* x_map
+ );
+
+ #endif
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cu ADDED
@@ -0,0 +1,260 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #include "q4_matmul.cuh"
4
+ #include "column_remap.cuh"
5
+ #include "../util.cuh"
6
+ #include "../matrix.cuh"
7
+ #include "../cu_compat.cuh"
8
+ #include "../cuda_buffers.cuh"
9
+ #if defined(USE_ROCM)
10
+ #include "../hip_compat.cuh"
11
+ #endif
12
+
13
+ const int THREADS_X = 32; // Block size and thread count along columns in w and out
14
+ const int THREADS_Y = 1; // Block size and thread count along rows in x and out
15
+
16
+ typedef void (*fp_q4_matmul_kernel)
17
+ (
18
+ const half*,
19
+ const uint32_t*,
20
+ half*,
21
+ const half*,
22
+ const uint32_t*,
23
+ const int,
24
+ const int,
25
+ const int,
26
+ const int,
27
+ const int,
28
+ const uint32_t*,
29
+ bool
30
+ );
31
+
32
+ template<bool use_half2, bool use_groupsize, bool use_x_map>
33
+ __global__ void q4_matmul_kernel
34
+ (
35
+ const half* __restrict__ x,
36
+ const uint32_t* __restrict__ w,
37
+ half* __restrict__ out,
38
+ const half* __restrict__ w_scales,
39
+ const uint32_t* __restrict__ w_zeros,
40
+ const int height,
41
+ const int dim,
42
+ const int width,
43
+ const int groupsize,
44
+ const int block_size_z,
45
+ const uint32_t* __restrict__ x_map,
46
+ bool no_zero
47
+ )
48
+ {
49
+ // Start of block
50
+
51
+ int x_column = block_size_z * blockIdx.z;
52
+ int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
53
+
54
+ int w_column = THREADS_X * blockIdx.x + threadIdx.x;
55
+ int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
56
+
57
+ int iterations = (x_column_end - x_column) / 8;
58
+
59
+ // Views
60
+
61
+ MatrixView_half x_(x, height, dim);
62
+ MatrixView_half w_scales_(w_scales, dim / groupsize, width);
63
+ MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
64
+ MatrixView_q4_column w_(w, dim, width);
65
+ MatrixView_half_rw out_(out, height, width);
66
+
67
+ // Zero output
68
+
69
+ if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
70
+ {
71
+ *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
72
+ __syncthreads();
73
+ }
74
+
75
+ // Loop over part of x row (and w column)
76
+
77
+ half2 acc = {};
78
+ half acc_h = {};
79
+
80
+ if constexpr (use_groupsize)
81
+ {
82
+ // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
83
+ // could be slightly faster
84
+
85
+ for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
86
+ {
87
+ if constexpr (use_half2)
88
+ {
89
+ half2 w_scale = w_scales_.item_half2half2(group, w_column);
90
+ uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
91
+
92
+ if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
93
+ else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
94
+ }
95
+ else
96
+ {
97
+ half w_scale = w_scales_.item(group, w_column);
98
+ uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
99
+
100
+ if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
101
+ else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
102
+ }
103
+ }
104
+ }
105
+ else
106
+ {
107
+ // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
108
+
109
+ for (int k = x_column; k < x_column + iterations * 8; k += 8)
110
+ {
111
+ if constexpr (use_half2)
112
+ {
113
+ int group = k / groupsize;
114
+ half2 w_scale = w_scales_.item_half2half2(group, w_column);
115
+ uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
116
+
117
+ if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
118
+ else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
119
+ }
120
+ else
121
+ {
122
+ int group = k / groupsize;
123
+ half w_scale = w_scales_.item(group, w_column);
124
+ uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
125
+
126
+ if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
127
+ else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
128
+ }
129
+ }
130
+ }
131
+
132
+ // Add to block result
133
+
134
+ if constexpr (use_half2)
135
+ {
136
+ half result = __hadd(__low2half(acc), __high2half(acc));
137
+ atomicAdd(out_.item_ptr(x_row, w_column), result);
138
+ }
139
+ else
140
+ {
141
+ atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
142
+ }
143
+ }
144
+
145
+ fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
146
+ {
147
+ // <bool use_half2, bool use_groupsize, bool use_x_map>
148
+ if (tuningParams->matmul_no_half2) {
149
+ if (block_size_z % groupsize == 0) {
150
+ if (x_map) return q4_matmul_kernel<false, true, true >;
151
+ else return q4_matmul_kernel<false, true, false>;
152
+ } else {
153
+ if (x_map) return q4_matmul_kernel<false, false, true >;
154
+ else return q4_matmul_kernel<false, false, false>;
155
+ }
156
+ } else {
157
+ if (block_size_z % groupsize == 0)
158
+ {
159
+ if (x_map) return q4_matmul_kernel<true, true, true >;
160
+ else return q4_matmul_kernel<true, true, false>;
161
+ } else {
162
+ if (x_map) return q4_matmul_kernel<true, false, true >;
163
+ else return q4_matmul_kernel<true, false, false>;
164
+ }
165
+ }
166
+ };
167
+
168
+ // Compute y = x @ w
169
+
170
+ void q4_matmul_cuda
171
+ (
172
+ ExLlamaTuning* tuningParams,
173
+ const half* x,
174
+ const int x_height,
175
+ const Q4Matrix* w,
176
+ half* out,
177
+ bool no_zero,
178
+ cudaStream_t alt_stream
179
+ )
180
+ {
181
+ int height = x_height;
182
+ int dim = w->height;
183
+ int width = w->width;
184
+
185
+ cudaSetDevice(w->device);
186
+
187
+ uint32_t* x_map = w->cuda_x_map;
188
+ const half* x_mapped = x;
189
+ if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
190
+ {
191
+ CudaBuffers* buffers = get_buffers(w->device);
192
+ column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
193
+ x_mapped = buffers->temp_state;
194
+ x_map = NULL;
195
+ }
196
+
197
+ int block_size_z;
198
+ if (w->width == 4096) block_size_z = 384; // 7B
199
+ else if (w->width == 11008) block_size_z = 256;
200
+ else if (w->width == 5120) block_size_z = 384; // 13B
201
+ else if (w->width == 13824) block_size_z = 256;
202
+ else if (w->width == 6656) block_size_z = 256; // 33B
203
+ else if (w->width == 17920) block_size_z = 128;
204
+ else block_size_z = 256;
205
+
206
+ //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
207
+
208
+ dim3 threads(THREADS_X, THREADS_Y, 1);
209
+
210
+ dim3 blocks
211
+ (
212
+ (width + threads.x - 1) / threads.x,
213
+ (height + threads.y - 1) / threads.y,
214
+ (dim + block_size_z - 1) / block_size_z
215
+ );
216
+
217
+ fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
218
+
219
+ kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
220
+ }
221
+
222
+ void q4_matmul_recons_cuda
223
+ (
224
+ ExLlamaTuning* tuningParams,
225
+ const half* x,
226
+ const int x_height,
227
+ Q4Matrix* w,
228
+ half* out,
229
+ const cublasHandle_t handle,
230
+ bool no_zero
231
+ )
232
+ {
233
+ int height = x_height;
234
+ int dim = w->height;
235
+ int width = w->width;
236
+
237
+ cudaSetDevice(w->device);
238
+ CudaBuffers* buffers = get_buffers(w->device);
239
+
240
+ const half* x_mapped = x;
241
+ if (w->cuda_x_map)
242
+ {
243
+ TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend for GPTQ with act-order. Please call the exllama_set_max_input_length function to increase the buffer size for a sequence length >=", x_height, ":\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, max_input_length=", x_height, ")");
244
+ column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
245
+ x_mapped = buffers->temp_state;
246
+ }
247
+
248
+ w->reconstruct(buffers->temp_dq);
249
+
250
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
251
+ const float alpha = 1.0f;
252
+ const float beta = no_zero ? 1.0f : 0.0f;
253
+ cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
254
+ x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
255
+ #else
256
+ const half alpha = __float2half(1.0f);
257
+ const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
258
+ cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
259
+ #endif
260
+ }
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cuh ADDED
@@ -0,0 +1,43 @@
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+ #ifndef _q4_matmul_cuh
+ #define _q4_matmul_cuh
+
+ #include <cuda_runtime.h>
+ #include <cuda_fp16.h>
+ #include <cstdint>
+ #include <cstdio>
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include "q4_matrix.cuh"
+ #include "../tuning.h"
+
+ // Workaround for hipify_python using rocblas instead of hipblas.
+ #if defined(USE_ROCM)
+ #include <hipblas/hipblas.h>
+ #define rocblas_handle hipblasHandle_t
+ #endif
+
+ void q4_matmul_cuda
+ (
+     ExLlamaTuning* tuningParams,
+     const half* x,
+     const int x_height,
+     const Q4Matrix* w,
+     half* out,
+     bool no_zero = false,
+     cudaStream_t alt_stream = NULL
+ );
+
+ void q4_matmul_recons_cuda
+ (
+     ExLlamaTuning* tuningParams,
+     const half* x,
+     const int x_height,
+     Q4Matrix* w,
+     half* out,
+     const cublasHandle_t handle,
+     bool no_zero = false
+ );
+
+ #endif
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cu ADDED
@@ -0,0 +1,227 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #include "q4_matrix.cuh"
4
+ #include <vector>
5
+ #include "../util.cuh"
6
+ #include "../matrix.cuh"
7
+
8
+ using namespace std;
9
+
10
+ const int UNSHUF_BLOCKSIZE_X = 64;
11
+
12
+ const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
13
+ const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
14
+
15
+ vector<Q4Matrix*> g_q4_matrices;
16
+
17
+ void g_q4_keep_matrix(Q4Matrix* m)
18
+ {
19
+ g_q4_matrices.push_back(m);
20
+ }
21
+
22
+ void g_q4_free_matrices()
23
+ {
24
+ for (const auto& m : g_q4_matrices) delete m;
25
+ g_q4_matrices.clear();
26
+ }
27
+
28
+ Q4Matrix::Q4Matrix
29
+ (
30
+ const int _height,
31
+ const int _width,
32
+ const int _groups,
33
+
34
+ uint32_t* _qweight,
35
+ uint32_t* _qzeros,
36
+ half* _scales,
37
+ uint32_t* _g_idx,
38
+
39
+ const int _device
40
+ ) :
41
+ height(_height),
42
+ width(_width),
43
+ groups(_groups),
44
+ device(_device)
45
+ {
46
+ cudaSetDevice(device);
47
+
48
+ cuda_qweight = _qweight;
49
+ cuda_qzeros = _qzeros;
50
+ cuda_scales = _scales;
51
+
52
+ groupsize = height / groups;
53
+
54
+ if (_g_idx) make_sequential(_g_idx);
55
+ }
56
+
57
+ Q4Matrix::~Q4Matrix()
58
+ {
59
+ }
60
+
61
+ // Make sequential
62
+
63
+ __global__ void make_sequential_kernel
64
+ (
65
+ const uint32_t* __restrict__ w,
66
+ uint32_t* __restrict__ w_new,
67
+ const uint32_t* __restrict__ x_map,
68
+ const int w_height,
69
+ const int w_width
70
+ )
71
+ {
72
+ const uint64_t* w2 = (uint64_t*) w;
73
+ uint64_t* w_new2 = (uint64_t*) w_new;
74
+ int w2_stride = w_width >> 1;
75
+
76
+ int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
77
+ if (w2_column >= w2_stride) return;
78
+
79
+ int w_new2_row = blockIdx.y;
80
+
81
+ int x_map_idx = w_new2_row << 3;
82
+
83
+ uint64_t dst = 0;
84
+
85
+ #pragma unroll
86
+ for (int i = 0; i < 8; i++)
87
+ {
88
+ int source_row = x_map[x_map_idx++];
89
+
90
+ int w2_row = source_row >> 3;
91
+ int w2_subrow = source_row & 0x07;
92
+ int w2_row_shift = w2_subrow << 2;
93
+ int wnew2_row_shift = i << 2;
94
+
95
+ uint64_t src = w2[w2_row * w2_stride + w2_column];
96
+ src >>= w2_row_shift;
97
+ src &= 0x0000000f0000000f;
98
+ src <<= wnew2_row_shift;
99
+ dst |= src;
100
+ }
101
+
102
+ w_new2[w_new2_row * w2_stride + w2_column] = dst;
103
+ }
104
+
105
+ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
106
+ {
107
+ uint32_t* cuda_new_qweight = NULL;
108
+ cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
109
+ cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
110
+
111
+ uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
112
+ uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
113
+ uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
114
+
115
+ // Group histogram
116
+
117
+ for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
118
+
119
+ // Group map
120
+
121
+ for (int i = 0, acc = 0; i < groups; i++)
122
+ {
123
+ short tmp = cpu_g_idx_map[i];
124
+ cpu_g_idx_map[i] = acc;
125
+ acc += tmp;
126
+ }
127
+
128
+ // X map (inverse)
129
+
130
+ for (int row = 0; row < height; row++)
131
+ {
132
+ uint32_t target_group = cpu_g_idx[row];
133
+ uint32_t target_row = cpu_g_idx_map[target_group];
134
+ cpu_g_idx_map[target_group]++;
135
+ cpu_x_map_inv[row] = target_row;
136
+ }
137
+
138
+ // X map
139
+
140
+ for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
141
+
142
+ // Move to CUDA
143
+
144
+ cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
145
+
146
+ // Rearrange rows in w
147
+
148
+ dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
149
+ dim3 blocks
150
+ (
151
+ (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
152
+ height / 8,
153
+ 1
154
+ );
155
+
156
+ make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
157
+
158
+ // Replace qweights
159
+
160
+ cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
161
+
162
+ // Cleanup
163
+
164
+ cudaDeviceSynchronize();
165
+ cudaFree(cuda_new_qweight);
166
+ free(cpu_g_idx_map);
167
+ free(cpu_x_map);
168
+ free(cpu_x_map_inv);
169
+ }
170
+
171
+ __global__ void reconstruct_kernel
172
+ (
173
+ const uint32_t* __restrict__ w,
174
+ half* __restrict__ out, // (y)
175
+ const half* __restrict__ w_scales,
176
+ const uint32_t* __restrict__ w_zeros,
177
+ const int height,
178
+ const int width,
179
+ const int groupsize
180
+ )
181
+ {
182
+ // Start of block
183
+
184
+ int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
185
+ int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
186
+ if (column >= width) return;
187
+
188
+ // Views
189
+
190
+ MatrixView_q4_column w_(w, height, width);
191
+ MatrixView_half_rw out_(out, height, width);
192
+ MatrixView_half w_scales_(w_scales, height / groupsize, width);
193
+ MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
194
+
195
+ // Groupsize version
196
+
197
+ int group = row / groupsize;
198
+
199
+ half w_scale = w_scales_.item(group, column);
200
+
201
+ //
202
+ uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0f;
203
+
204
+ uint32_t w_read = w_.item_uint32_t(row, column);
205
+ half* out_ptr = out_.item_ptr(row, column);
206
+
207
+ #pragma unroll
208
+ for (int s = 0; s < 32; s += 4)
209
+ {
210
+ half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
211
+ *out_ptr = w_item; out_ptr += out_.width;
212
+ }
213
+ }
214
+
215
+ void Q4Matrix::reconstruct(half* out)
216
+ {
217
+ dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
218
+
219
+ dim3 blocks
220
+ (
221
+ (width + threads.x - 1) / threads.x,
222
+ (height / 8 + threads.y - 1) / threads.y,
223
+ 1
224
+ );
225
+
226
+ reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
227
+ }
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cuh ADDED
@@ -0,0 +1,53 @@
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+ #ifndef _q4_matrix_cuh
+ #define _q4_matrix_cuh
+
+ #include <cuda_runtime.h>
+ #include <cuda_fp16.h>
+ #include <cstdint>
+
+ class Q4Matrix
+ {
+ public:
+
+     int device;
+
+     int height;
+     int width;
+     int groups;
+     int groupsize;
+
+     uint32_t* cuda_qweight = NULL;
+     uint32_t* cuda_qzeros = NULL;
+     half* cuda_scales = NULL;
+     uint32_t* cuda_x_map = NULL;
+
+     Q4Matrix
+     (
+         const int _height,
+         const int _width,
+         const int _groups,
+
+         uint32_t* _qweight,
+         uint32_t* _qzeros,
+         half* _scales,
+         uint32_t* _g_idx,
+
+         const int _device
+     );
+
+     ~Q4Matrix();
+
+     void reconstruct(half* out);
+
+ private:
+
+     void make_sequential(const uint32_t* cpu_g_idx);
+
+ };
+
+ void g_q4_keep_matrix(Q4Matrix* m);
+ void g_q4_free_matrices();
+
+ #endif
AutoAWQ_kernels/awq_ext/exllama/exllama_ext.cpp ADDED
@@ -0,0 +1,260 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #include <torch/extension.h>
4
+ #include <c10/cuda/CUDAGuard.h>
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <cuda_runtime.h>
7
+ #include <cuda_fp16.h>
8
+ #include <cstdint>
9
+ #include <cstdio>
10
+ #include "util.cuh"
11
+ #include "tuning.h"
12
+ #include "cuda_buffers.cuh"
13
+ #include "cuda_func/q4_matrix.cuh"
14
+ #include "cuda_func/q4_matmul.cuh"
15
+ #include "cuda_func/column_remap.cuh"
16
+
17
+ #include <typeinfo>
18
+ #include <limits>
19
+ #include <algorithm>
20
+
21
+ // Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
22
+ // minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
23
+ // exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
24
+
25
+ void check_cuda(cudaError_t ret)
26
+ {
27
+ switch (ret)
28
+ {
29
+ case cudaSuccess:
30
+ break;
31
+
32
+ case cudaUnspecified:
33
+ printf(" **** Unspecified error\n");
34
+ TORCH_CHECK(false, "CUDA error");
35
+ break;
36
+
37
+ default:
38
+ printf(" **** CUDA error\n"); \
39
+ printf(" **** %s\n", cudaGetErrorString(ret)); \
40
+ TORCH_CHECK(false, "CUDA error"); \
41
+ break;
42
+ }
43
+ }
44
+
45
+ // Some decluttering macros
46
+
47
+ #define STRINGIFY_(__x) #__x
48
+ #define STRINGIFY(__x) STRINGIFY_(__x)
49
+ #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
50
+ #define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
51
+ #define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
52
+ #define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
53
+ #define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
54
+ #define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
55
+
56
+ #define TORCH_CHECK_DEVICE_INDEX(__index) \
57
+ do { \
58
+ TORCH_CHECK(__index >= 0, "no device index"); \
59
+ TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
60
+ } while(0)
61
+
62
+ #define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
63
+ do { \
64
+ TORCH_CHECK_DTYPE(__w, kInt); \
65
+ TORCH_CHECK_DTYPE(__w_scales, kHalf); \
66
+ TORCH_CHECK_DTYPE(__w_zeros, kInt); \
67
+ TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
68
+ TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
69
+ TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
70
+ TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
71
+ } while(0)
72
+
73
+ int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
74
+ {
75
+ int groupsize = w.size(0) * 8 / w_zeros.size(0);
76
+ TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
77
+ return groupsize;
78
+ }
79
+
80
+
81
+ // Tuning parameters
82
+
83
+ ExLlamaTuning tuningParams;
84
+
85
+ void set_tuning_params
86
+ (
87
+ int matmul_recons_thd,
88
+ bool matmul_fused_remap,
89
+ bool matmul_no_half2
90
+ )
91
+ {
92
+ tuningParams.matmul_recons_thd = matmul_recons_thd;
93
+ tuningParams.matmul_fused_remap = matmul_fused_remap;
94
+ tuningParams.matmul_no_half2 = matmul_no_half2;
95
+ }
96
+
97
+
98
+ // Release all unmanaged objects allocated by the extension
99
+
100
+ void cleanup()
101
+ {
102
+ cleanup_buffers_cuda();
103
+ g_q4_free_matrices();
104
+ }
105
+
106
+
107
+ // Prepare buffers for forward pass
108
+
109
+ void prepare_buffers
110
+ (
111
+ torch::Device device,
112
+ torch::Tensor temp_state,
113
+ torch::Tensor temp_dq
114
+ )
115
+ {
116
+ int device_index = device.index();
117
+ TORCH_CHECK_DEVICE_INDEX(device_index);
118
+ const at::cuda::OptionalCUDAGuard device_guard(device);
119
+ const long max_int = std::numeric_limits<int>::max();
120
+
121
+ prepare_buffers_cuda
122
+ (
123
+ device_index,
124
+ // buffer size used for sanity checks
125
+ std::clamp((long)temp_state.numel(), (long)0, max_int),
126
+ (half*) temp_state.data_ptr(),
127
+ (half*) temp_dq.data_ptr()
128
+ );
129
+ }
130
+
131
+
132
+ // Create Q4Matrix, return handle
133
+
134
+ uintptr_t make_q4
135
+ (
136
+ torch::Tensor qweight,
137
+ torch::Tensor qzeros,
138
+ torch::Tensor scales,
139
+ torch::Tensor g_idx,
140
+ int device
141
+ )
142
+ {
143
+ TORCH_CHECK_DTYPE(qweight, kInt);
144
+ TORCH_CHECK_DTYPE(qzeros, kInt);
145
+ TORCH_CHECK_DTYPE(scales, kHalf);
146
+ TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
147
+ TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
148
+ TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
149
+ TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
150
+
151
+ int width = qweight.size(1);
152
+ int height = qweight.size(0) * 8;
153
+ int groups = qzeros.size(0);
154
+
155
+ Q4Matrix* m = new Q4Matrix
156
+ (
157
+ height,
158
+ width,
159
+ groups,
160
+
161
+ (uint32_t*) qweight.data_ptr(),
162
+ (uint32_t*) qzeros.data_ptr(),
163
+ (half*) scales.data_ptr(),
164
+ g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
165
+
166
+ device
167
+ );
168
+
169
+ g_q4_keep_matrix(m);
170
+ return reinterpret_cast<uintptr_t> (m);
171
+ }
172
+
173
+
174
+ // Matmul half @ quant -> half
175
+
176
+ void q4_matmul
177
+ (
178
+ torch::Tensor x,
179
+ uintptr_t w,
180
+ torch::Tensor out
181
+ )
182
+ {
183
+ Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
184
+
185
+ TORCH_CHECK_DTYPE(x, kHalf);
186
+ TORCH_CHECK_DTYPE(out, kHalf);
187
+ TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
188
+ TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
189
+
190
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
191
+
192
+ int x_height = x.size(0);
193
+
194
+ if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
195
+ {
196
+ q4_matmul_cuda
197
+ (
198
+ &tuningParams,
199
+ (half*) x.data_ptr(),
200
+ x_height,
201
+ wm,
202
+ (half*) out.data_ptr()
203
+ );
204
+ }
205
+ else
206
+ {
207
+ q4_matmul_recons_cuda
208
+ (
209
+ &tuningParams,
210
+ (half*) x.data_ptr(),
211
+ x_height,
212
+ wm,
213
+ (half*) out.data_ptr(),
214
+ at::cuda::getCurrentCUDABlasHandle()
215
+ );
216
+ }
217
+ }
218
+
219
+
220
+ // Remap columns in half tensor
221
+
222
+ void column_remap
223
+ (
224
+ torch::Tensor x,
225
+ torch::Tensor x_new,
226
+ torch::Tensor x_map
227
+ )
228
+ {
229
+ TORCH_CHECK_DTYPE(x, kHalf);
230
+ TORCH_CHECK_DTYPE(x_new, kHalf);
231
+ TORCH_CHECK_DTYPE(x_map, kInt);
232
+ TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
233
+
234
+ int height = x.size(0);
235
+ int width = x.size(1);
236
+
237
+ TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
238
+
239
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
240
+
241
+ column_remap_cuda
242
+ (
243
+ (half*) x.data_ptr(),
244
+ (half*) x_new.data_ptr(),
245
+ height,
246
+ width,
247
+ (uint32_t*) x_map.data_ptr()
248
+ );
249
+ }
250
+
251
+
252
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
253
+ {
254
+ m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
255
+ m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
256
+ m.def("cleanup", &cleanup, "cleanup");
257
+ m.def("make_q4", &make_q4, "make_q4");
258
+ m.def("q4_matmul", &q4_matmul, "q4_matmul");
259
+ m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
260
+ }
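The bindings above imply a call order: set_tuning_params, then prepare_buffers once per device, then make_q4 to wrap the quantized tensors in a Q4Matrix handle, and finally q4_matmul. In practice these are invoked from Python through the extension module; the hypothetical C++ sketch below (made-up sizes, scratch-buffer shapes, and a meta g_idx for the no-act-order case, matching the is_meta() check in make_q4) only illustrates argument shapes and ordering:

// Hypothetical usage sketch of the functions bound in PYBIND11_MODULE above.
#include <torch/torch.h>
#include <cstdint>

int main()
{
    const int in_features = 4096, out_features = 4096, groupsize = 128;
    auto dev = torch::Device(torch::kCUDA, 0);
    auto i32 = torch::dtype(torch::kInt).device(dev);
    auto f16 = torch::dtype(torch::kHalf).device(dev);

    auto qweight = torch::zeros({in_features / 8, out_features}, i32);             // 8 x 4-bit per int32
    auto qzeros  = torch::zeros({in_features / groupsize, out_features / 8}, i32);
    auto scales  = torch::ones({in_features / groupsize, out_features}, f16);
    auto g_idx   = torch::empty({0}, torch::TensorOptions().device(torch::kMeta)); // meta => no act-order

    set_tuning_params(/*matmul_recons_thd=*/8, /*matmul_fused_remap=*/false, /*matmul_no_half2=*/false);

    auto temp_state = torch::zeros({2048 * (int64_t)out_features}, f16);           // scratch for act-order remap
    auto temp_dq    = torch::zeros({(int64_t)in_features * out_features}, f16);    // scratch for reconstruction
    prepare_buffers(dev, temp_state, temp_dq);

    uintptr_t w = make_q4(qweight, qzeros, scales, g_idx, /*device=*/0);

    auto x   = torch::randn({1, in_features}, f16);
    auto out = torch::empty({1, out_features}, f16);
    q4_matmul(x, w, out);   // out = x @ dequant(w)
    return 0;
}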
AutoAWQ_kernels/awq_ext/exllama/hip_compat.cuh ADDED
@@ -0,0 +1,51 @@
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+ #ifndef _hip_compat_cuh
+ #define _hip_compat_cuh
+
+ // Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
+ __device__ __forceinline__ __half __compat_hrcp(__half x) {
+     return __half_raw{
+         static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+ }
+
+ // ROCm 6.0 compatible from: /opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_fp16.h:1708
+ __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+     return _Float16_2{_Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data};
+ }
+
+ #define hrcp __compat_hrcp
+ #define h2rcp __compat_h2rcp
+
+ // Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
+ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+                                                                hipblasOperation_t transA,
+                                                                hipblasOperation_t transB,
+                                                                int m,
+                                                                int n,
+                                                                int k,
+                                                                const half* alpha,
+                                                                const half* AP,
+                                                                int lda,
+                                                                const half* BP,
+                                                                int ldb,
+                                                                const half* beta,
+                                                                half* CP,
+                                                                int ldc) {
+     return hipblasHgemm(handle, transA, transB, m, n, k,
+                         reinterpret_cast<const hipblasHalf *>(alpha),
+                         reinterpret_cast<const hipblasHalf *>(AP), lda,
+                         reinterpret_cast<const hipblasHalf *>(BP), ldb,
+                         reinterpret_cast<const hipblasHalf *>(beta),
+                         reinterpret_cast<hipblasHalf *>(CP), ldc);
+ }
+ #define hipblasHgemm __compat_hipblasHgemm
+
+ // Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
+ #define rocblas_handle hipblasHandle_t
+ #define rocblas_operation_none HIPBLAS_OP_N
+ #define rocblas_get_stream hipblasGetStream
+ #define rocblas_set_stream hipblasSetStream
+ #define rocblas_hgemm __compat_hipblasHgemm
+
+ #endif
AutoAWQ_kernels/awq_ext/exllama/matrix.cuh ADDED
@@ -0,0 +1,294 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #ifndef _matrix_cuh
4
+ #define _matrix_cuh
5
+
6
+ #include <cuda_runtime.h>
7
+ #include <cuda_fp16.h>
8
+
9
+ class MatrixView_half
10
+ {
11
+ public:
12
+ const half* data;
13
+ const int height;
14
+ const int width;
15
+
16
+ __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
17
+ : data(data), height(height), width(width)
18
+ { }
19
+
20
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
21
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
22
+ __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
23
+ __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
24
+ };
25
+
26
+ class MatrixView_half_rw
27
+ {
28
+ public:
29
+ half* data;
30
+ const int height;
31
+ const int width;
32
+
33
+ __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
34
+ : data(data), height(height), width(width)
35
+ { }
36
+
37
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
38
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
39
+ __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
40
+ __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
41
+ __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
42
+ };
43
+
44
+ class MatrixView_q4_row
45
+ {
46
+ public:
47
+ const uint32_t* data;
48
+ const int height;
49
+ const int width;
50
+
51
+ __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
52
+ : data(data), height(height), width(width)
53
+ { }
54
+
55
+ __device__ __forceinline__ int item(int row, int column) const
56
+ {
57
+ int shift = (column & 0x07) * 4;
58
+ return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
59
+ }
60
+ };
61
+
62
+ class MatrixView_q4_column
63
+ {
64
+ public:
65
+ const uint32_t* data;
66
+ const int height;
67
+ const int width;
68
+
69
+ __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
70
+ : data(data), height(height), width(width)
71
+ { }
72
+
73
+ __device__ __forceinline__ int item(int row, int column) const
74
+ {
75
+ int shift = (row & 0x07) * 4;
76
+ return (data[row / 8 * width + column] >> shift) & 0x0f;
77
+ }
78
+
79
+ __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
80
+ __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
81
+ };
82
+
83
+ // TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
84
+
85
+ // Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
86
+
87
+ __device__ __forceinline__ half2 dot_product_8
88
+ (
89
+ const half2 acc,
90
+ MatrixView_half& h_,
91
+ const int h_row,
92
+ const int h_column, // divisible by 8
93
+ MatrixView_q4_column& v_,
94
+ const int v_row, // divisible by 8
95
+ const int v_column,
96
+ const half2 v_scale_2,
97
+ const uint32_t v_zero, // + 1 (!!)
98
+ const int count
99
+ )
100
+ {
101
+ const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
102
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
103
+ half2 result = acc;
104
+
105
+ for (int i = 0; i < count; i++)
106
+ {
107
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
108
+
109
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
110
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
111
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
112
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
113
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
114
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
115
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
116
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
117
+
118
+ half2 v_01 = __halves2half2(v_0, v_1);
119
+ half2 v_23 = __halves2half2(v_2, v_3);
120
+ half2 v_45 = __halves2half2(v_4, v_5);
121
+ half2 v_67 = __halves2half2(v_6, v_7);
122
+
123
+ // half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
124
+ // half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
125
+ // half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
126
+ // half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
127
+
128
+ half2 tmp = __hmul2(*h_ptr++, v_01);
129
+ tmp = __hfma2(*h_ptr++, v_23, tmp);
130
+ tmp = __hfma2(*h_ptr++, v_45, tmp);
131
+ tmp = __hfma2(*h_ptr++, v_67, tmp);
132
+ result = __hfma2(v_scale_2, tmp, result);
133
+ }
134
+
135
+ return result;
136
+ }
137
+
138
+ __device__ __forceinline__ half dot_product_8_h
139
+ (
140
+ const half acc,
141
+ MatrixView_half& h_,
142
+ const int h_row,
143
+ const int h_column, // divisible by 8
144
+ MatrixView_q4_column& v_,
145
+ const int v_row, // divisible by 8
146
+ const int v_column,
147
+ const half v_scale,
148
+ const uint32_t v_zero, // + 1 (!!)
149
+ const int count
150
+ )
151
+ {
152
+ const half* h_ptr = h_.item_ptr(h_row, h_column);
153
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
154
+ half result = acc;
155
+
156
+ for (int i = 0; i < count; i++)
157
+ {
158
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
159
+
160
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
161
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
162
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
163
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
164
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
165
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
166
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
167
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
168
+
169
+ half tmp = __hmul(*h_ptr++, v_0);
170
+ tmp = __hfma(*h_ptr++, v_1, tmp);
171
+ tmp = __hfma(*h_ptr++, v_2, tmp);
172
+ tmp = __hfma(*h_ptr++, v_3, tmp);
173
+ tmp = __hfma(*h_ptr++, v_4, tmp);
174
+ tmp = __hfma(*h_ptr++, v_5, tmp);
175
+ tmp = __hfma(*h_ptr++, v_6, tmp);
176
+ tmp = __hfma(*h_ptr++, v_7, tmp);
177
+ result = __hfma(v_scale, tmp, result);
178
+ }
179
+
180
+ return result;
181
+ }
182
+
183
+ // Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
184
+
185
+ __device__ __forceinline__ half2 dot_product_8_x_map
186
+ (
187
+ const half2 acc,
188
+ MatrixView_half& h_,
189
+ const int h_row,
190
+ const int h_column, // divisible by 8
191
+ MatrixView_q4_column& v_,
192
+ const int v_row, // divisible by 8
193
+ const int v_column,
194
+ const half2 v_scale_2,
195
+ const uint32_t v_zero, // + 1 (!!)
196
+ const int count,
197
+ const uint32_t* x_map
198
+ )
199
+ {
200
+ const half* h_ptr = h_.item_ptr(h_row, 0);
201
+ const uint32_t* x_map_ptr = x_map + h_column;
202
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
203
+ half2 result = acc;
204
+
205
+ for (int i = 0; i < count; i++)
206
+ {
207
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
208
+
209
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
210
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
211
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
212
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
213
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
214
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
215
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
216
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
217
+
218
+ half2 v_01 = __halves2half2(v_0, v_1);
219
+ half2 v_23 = __halves2half2(v_2, v_3);
220
+ half2 v_45 = __halves2half2(v_4, v_5);
221
+ half2 v_67 = __halves2half2(v_6, v_7);
222
+
223
+ half h_0 = h_ptr[*x_map_ptr++];
224
+ half h_1 = h_ptr[*x_map_ptr++];
225
+ half h_2 = h_ptr[*x_map_ptr++];
226
+ half h_3 = h_ptr[*x_map_ptr++];
227
+ half h_4 = h_ptr[*x_map_ptr++];
228
+ half h_5 = h_ptr[*x_map_ptr++];
229
+ half h_6 = h_ptr[*x_map_ptr++];
230
+ half h_7 = h_ptr[*x_map_ptr++];
231
+
232
+ half2 h_01 = __halves2half2(h_0, h_1);
233
+ half2 h_23 = __halves2half2(h_2, h_3);
234
+ half2 h_45 = __halves2half2(h_4, h_5);
235
+ half2 h_67 = __halves2half2(h_6, h_7);
236
+
237
+ half2 tmp = __hmul2(h_01, v_01);
238
+ tmp = __hfma2(h_23, v_23, tmp);
239
+ tmp = __hfma2(h_45, v_45, tmp);
240
+ tmp = __hfma2(h_67, v_67, tmp);
241
+ result = __hfma2(v_scale_2, tmp, result);
242
+ }
243
+
244
+ return result;
245
+ }
246
+
247
+ __device__ __forceinline__ half dot_product_8_x_map_h
248
+ (
249
+ const half acc,
250
+ MatrixView_half& h_,
251
+ const int h_row,
252
+ const int h_column, // divisible by 8
253
+ MatrixView_q4_column& v_,
254
+ const int v_row, // divisible by 8
255
+ const int v_column,
256
+ const half v_scale,
257
+ const uint32_t v_zero, // + 1 (!!)
258
+ const int count,
259
+ const uint32_t* x_map
260
+ )
261
+ {
262
+ const half* h_ptr = h_.item_ptr(h_row, 0);
263
+ const uint32_t* x_map_ptr = x_map + h_column;
264
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
265
+ half result = acc;
266
+
267
+ for (int i = 0; i < count; i++)
268
+ {
269
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
270
+
271
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
272
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
273
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
274
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
275
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
276
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
277
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
278
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
279
+
280
+ half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
281
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
282
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
283
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
284
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
285
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
286
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
287
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
288
+ result = __hfma(v_scale, tmp, result);
289
+ }
290
+
291
+ return result;
292
+ }
293
+
294
+ #endif
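A minimal host-side sketch of the nibble unpacking used by the dot-product helpers above (the packed word and zero point are hypothetical; the kernels perform the same shift-and-mask per 4-bit weight, with the zero point already incremented by one):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        uint32_t v_read = 0x87654321u;   // nibbles 0..7 hold the values 1..8, lowest nibble first
        int v_zero = 1;                  // zero point + 1, as the kernels expect

        for (int i = 0; i < 8; i++)
        {
            int q = (int)((v_read >> (4 * i)) & 0x0f);
            printf("weight %d = %d\n", i, q - v_zero);   // prints 0 through 7
        }
        return 0;
    }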
AutoAWQ_kernels/awq_ext/exllama/tuning.h ADDED
@@ -0,0 +1,13 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #ifndef _tuning_h
4
+ #define _tuning_h
5
+
6
+ struct ExLlamaTuning
7
+ {
8
+ int matmul_recons_thd;
9
+ bool matmul_fused_remap;
10
+ bool matmul_no_half2;
11
+ };
12
+
13
+ #endif
AutoAWQ_kernels/awq_ext/exllama/util.cuh ADDED
@@ -0,0 +1,33 @@
1
+ // Adapted from turboderp exllama: https://github.com/turboderp/exllama
2
+
3
+ #ifndef _util_cuh
4
+ #define _util_cuh
5
+
6
+ #include <cuda_runtime.h>
7
+ #include <cuda_fp16.h>
8
+ #include <cstdint>
9
+ #include <cstdio>
10
+
11
+ #if defined(USE_ROCM)
12
+ #define cudaUnspecified hipErrorUnknown
13
+ #else
14
+ #define cudaUnspecified cudaErrorApiFailureBase
15
+ #endif
16
+
17
+ // React to failure on return code != cudaSuccess
18
+
19
+ #define _cuda_check(fn) \
20
+ do { \
21
+ {_cuda_err = fn;} \
22
+ if (_cuda_err != cudaSuccess) goto _cuda_fail; \
23
+ } while(false)
24
+
25
+ // React to failure on return code == 0
26
+
27
+ #define _alloc_check(fn) \
28
+ do { \
29
+ if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
30
+ else _cuda_err = cudaSuccess; \
31
+ } while(false)
32
+
33
+ #endif
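These macros assume the enclosing function declares a cudaError_t named _cuda_err and provides a _cuda_fail label; a minimal usage sketch under those assumptions (the scratch-buffer allocation is illustrative only):

    // assumes util.cuh above is included
    cudaError_t alloc_scratch(void** out, size_t bytes)
    {
        cudaError_t _cuda_err = cudaSuccess;

        _cuda_check(cudaMalloc(out, bytes));   // jumps to _cuda_fail on any CUDA error
        _alloc_check(*out != nullptr);         // jumps to _cuda_fail when the condition is false

        return cudaSuccess;

    _cuda_fail:
        return _cuda_err;
    }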
AutoAWQ_kernels/awq_ext/exllamav2/config.h ADDED
@@ -0,0 +1,13 @@
1
+ #ifndef _config_h
2
+ #define _config_h
3
+
4
+ #define MAX_Q_GEMM_ROWS 50
5
+
6
+ #define QMODE_2BIT 1
7
+ #define QMODE_3BIT 1
8
+ #define QMODE_4BIT 1
9
+ #define QMODE_5BIT 1
10
+ #define QMODE_6BIT 0
11
+ #define QMODE_8BIT 0
12
+
13
+ #endif
AutoAWQ_kernels/awq_ext/exllamav2/cpp/util.h ADDED
@@ -0,0 +1,12 @@
1
+ #ifndef _util_h
2
+ #define _util_h
3
+
4
+ #define DBGS(__x) printf("%s\n", __x)
5
+ #define DBGI(__x) printf("%s: %i\n", #__x, __x)
6
+ #define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
7
+ #define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
8
+ #define DBGF(__x) printf("%s: %f\n", #__x, __x)
9
+ #define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
10
+ #define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
11
+
12
+ #endif
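The DBG* macros stringize their arguments, so each printout names the expression next to its value; a hypothetical example:

    int rows = 4, cols = 128;
    DBGI2(rows, cols);   // prints "rows, cols: 4, 128"
    DBGF(0.5f);          // prints "0.5f: 0.500000"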
AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat.cuh ADDED
@@ -0,0 +1,56 @@
1
+ #ifndef _compat_cuh
2
+ #define _compat_cuh
3
+
4
+ // atomicAdd for half types, to support CC < 7.x
5
+
6
+ __device__ __forceinline__ void atomicAdd_half(half* address, half val)
7
+ {
8
+ unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
9
+ unsigned int old = *address_as_ui;
10
+ unsigned int assumed;
11
+
12
+ do
13
+ {
14
+ assumed = old;
15
+ __half_raw hsum;
16
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
17
+ half tmpres = __hadd(hsum, val);
18
+ hsum = __half_raw(tmpres);
19
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
20
+ old = atomicCAS(address_as_ui, assumed, old);
21
+ }
22
+ while (assumed != old);
23
+ }
24
+
25
+ // atomicAdd for half2 types
26
+
27
+ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
28
+ {
29
+ unsigned int* address_as_ui = (unsigned int*)address;
30
+ unsigned int old = *address_as_ui;
31
+ unsigned int assumed;
32
+ do
33
+ {
34
+ assumed = old;
35
+ half2 old_val = *((half2*)&old);
36
+ half2 new_val = __hadd2(old_val, val);
37
+ old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
38
+ }
39
+ while (assumed != old);
40
+ }
41
+
42
+ //
43
+
44
+ #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
45
+ #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
46
+
47
+ __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
48
+
49
+ #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
50
+ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
51
+ #endif
52
+
53
+ #endif
54
+ #endif
55
+
56
+ #endif
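The CAS loops above emulate atomicAdd for half and half2 on architectures without native support (and on ROCm); a small hypothetical kernel shows the intended call pattern, which resolves to either the hardware instruction or the fallback depending on __CUDA_ARCH__:

    // assumes compat.cuh above is included
    __global__ void accumulate_half(half* out, const half* val, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) atomicAdd(&out[i % 32], val[i]);   // many threads may add into the same slot
    }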
AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat_gemm.cuh ADDED
@@ -0,0 +1,38 @@
1
+ #ifndef _compat_gemm_cuh
2
+ #define _compat_gemm_cuh
3
+
4
+ #if defined(USE_ROCM)
5
+
6
+ // For some reason this include is not present anywhere in the exllama_v2 codebase, but it is required
7
+ // for symbols such as hipblasHalf.
8
+ #include <hipblas/hipblas.h>
9
+
10
+ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
11
+ hipblasOperation_t transA,
12
+ hipblasOperation_t transB,
13
+ int m,
14
+ int n,
15
+ int k,
16
+ const half* alpha,
17
+ const half* AP,
18
+ int lda,
19
+ const half* BP,
20
+ int ldb,
21
+ const half* beta,
22
+ half* CP,
23
+ int ldc) {
24
+ return hipblasHgemm(handle, transA, transB, m, n, k,
25
+ reinterpret_cast<const hipblasHalf *>(alpha),
26
+ reinterpret_cast<const hipblasHalf *>(AP), lda,
27
+ reinterpret_cast<const hipblasHalf *>(BP), ldb,
28
+ reinterpret_cast<const hipblasHalf *>(beta),
29
+ reinterpret_cast<hipblasHalf *>(CP), ldc);
30
+ }
31
+ #define hipblasHgemm __compat_hipblasHgemm
32
+
33
+ // Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
34
+ #define rocblas_operation_none HIPBLAS_OP_N
35
+ #define rocblas_hgemm __compat_hipblasHgemm
36
+ #endif
37
+
38
+ #endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/matrix_view.cuh ADDED
@@ -0,0 +1,121 @@
1
+ #ifndef _matrix_view_cuh
2
+ #define _matrix_view_cuh
3
+
4
+ #include <cuda_runtime.h>
5
+ #include <cuda_fp16.h>
6
+
7
+ #include "quant/qdq_util.cuh"
8
+
9
+ class MatrixView_half
10
+ {
11
+ public:
12
+ const half* data;
13
+ const int height;
14
+ const int width;
15
+
16
+ __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
17
+ : data(data), height(height), width(width)
18
+ { }
19
+
20
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
21
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
22
+ __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
23
+ __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
24
+
25
+ __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
26
+ {
27
+ half2* ptr = (half2*) item_ptr(row, column);
28
+ half2 i01 = ptr[0];
29
+ half2 i23 = ptr[1];
30
+ items[0] = __low2half(i01);
31
+ items[1] = __high2half(i01);
32
+ items[2] = __low2half(i23);
33
+ items[3] = __high2half(i23);
34
+ }
35
+ __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
36
+ {
37
+ half2* ptr = (half2*)item_ptr(row, column);
38
+ half2 i01 = ptr[0];
39
+ half2 i23 = ptr[1];
40
+ items[0] = __half2float(__low2half(i01));
41
+ items[1] = __half2float(__high2half(i01));
42
+ items[2] = __half2float(__low2half(i23));
43
+ items[3] = __half2float(__high2half(i23));
44
+ }
45
+
46
+ __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
47
+ {
48
+ half2* ptr = (half2*)item_ptr(row, column);
49
+ half2 i01 = ptr[0];
50
+ half2 i23 = ptr[1];
51
+ items[0] = __half2half2(__low2half(i01));
52
+ items[1] = __half2half2(__high2half(i01));
53
+ items[2] = __half2half2(__low2half(i23));
54
+ items[3] = __half2half2(__high2half(i23));
55
+ }
56
+ };
57
+
58
+ class MatrixView_half_rw
59
+ {
60
+ public:
61
+ half* data;
62
+ const int height;
63
+ const int width;
64
+
65
+ __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
66
+ : data(data), height(height), width(width)
67
+ { }
68
+
69
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
70
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
71
+ __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
72
+ __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
73
+ __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
74
+
75
+ __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
76
+ {
77
+ half2 v01 = __halves2half2(v0, v1);
78
+ half2 v23 = __halves2half2(v2, v3);
79
+ half2* ptr = (half2*) item_ptr(row, column);
80
+ ptr[0] = v01;
81
+ ptr[1] = v23;
82
+ }
83
+ };
84
+
85
+ class MatrixView_q4_row
86
+ {
87
+ public:
88
+ const uint32_t* data;
89
+ const int height;
90
+ const int width;
91
+
92
+ __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
93
+ : data(data), height(height), width(width)
94
+ { }
95
+
96
+ __device__ __forceinline__ int item(int row, int column) const
97
+ {
98
+ int shift = (column & 0x07) * 4;
99
+ return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
100
+ }
101
+
102
+ __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
103
+ {
104
+ int shift = (column & 0x07) * 4;
105
+ uint32_t d = data[row * width / 8 + column / 8] >> shift;
106
+ items[0] = d & 0x0f;
107
+ items[1] = (d >> 4) & 0x0f;
108
+ }
109
+
110
+ __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
111
+ {
112
+ int shift = (column & 0x07) * 4;
113
+ uint32_t d = data[row * width / 8 + column / 8] >> shift;
114
+ items[0] = d & 0x0f;
115
+ items[1] = (d >> 4) & 0x0f;
116
+ items[2] = (d >> 8) & 0x0f;
117
+ items[3] = (d >> 12) & 0x0f;
118
+ }
119
+ };
120
+
121
+ #endif
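MatrixView_q4_row packs eight consecutive 4-bit elements of a row into one uint32_t; a worked example of the indexing in item(), with a hypothetical width of 4096:

    // element (row = 2, column = 19):
    //   word index = 2 * 4096 / 8 + 19 / 8 = 1026
    //   shift      = (19 & 0x07) * 4 = 12, i.e. bits 12..15 of data[1026]
    //   value      = (data[1026] >> 12) & 0x0f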
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cu ADDED
@@ -0,0 +1,211 @@
1
+ #include "q_gemm.cuh"
2
+ #include "util.cuh"
3
+ #include "matrix_view.cuh"
4
+ #include "../config.h"
5
+
6
+ #include "quant/qdq_2.cuh"
7
+ #include "quant/qdq_3.cuh"
8
+ #include "quant/qdq_4.cuh"
9
+ #include "quant/qdq_5.cuh"
10
+ #include "quant/qdq_6.cuh"
11
+ #include "quant/qdq_8.cuh"
12
+
13
+ #define BLOCK_KN_SIZE 128
14
+ #define BLOCK_M_SIZE_MAX 8
15
+ #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
16
+ #define CLEAR_N_SIZE 256
17
+
18
+ #include "q_gemm_kernel.cuh"
19
+ #include "q_gemm_kernel_gptq.cuh"
20
+
21
+ #include "compat_gemm.cuh"
22
+
23
+ void gemm_half_q_half_cuda_part
24
+ (
25
+ const half* a,
26
+ QMatrix* b,
27
+ half* c,
28
+ int size_m,
29
+ int size_n,
30
+ int size_k,
31
+ int m_count,
32
+ bool clear
33
+ )
34
+ {
35
+ if (!b->is_gptq)
36
+ {
37
+ dim3 blockDim, gridDim;
38
+ blockDim.x = BLOCK_KN_SIZE;
39
+ blockDim.y = 1;
40
+ blockDim.z = 1;
41
+ gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
42
+ gridDim.y = DIVIDE(size_m, m_count);
43
+ gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
44
+
45
+ fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
46
+
47
+ kernel<<<gridDim, blockDim>>>
48
+ (
49
+ a,
50
+ b->cuda_q_weight,
51
+ b->cuda_q_scale,
52
+ b->cuda_q_scale_max,
53
+ c,
54
+ size_m,
55
+ size_n,
56
+ size_k,
57
+ b->groups,
58
+ b->groupsize,
59
+ b->cuda_q_perm,
60
+ b->rows_8,
61
+ b->rows_6,
62
+ b->rows_5,
63
+ b->rows_4,
64
+ b->rows_3,
65
+ b->rows_2,
66
+ clear
67
+ );
68
+ }
69
+ else
70
+ {
71
+ dim3 blockDim, gridDim;
72
+ blockDim.x = BLOCK_KN_SIZE;
73
+ blockDim.y = 1;
74
+ blockDim.z = 1;
75
+ gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
76
+ gridDim.y = DIVIDE(size_m, m_count);
77
+ gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
78
+
79
+ fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
80
+
81
+ // DBGX((uint64_t) b->cuda_q_perm);
82
+ // DBGI(b->rows_4);
83
+ // DBGI(b->height);
84
+
85
+ kernel<<<gridDim, blockDim>>>
86
+ (
87
+ a,
88
+ b->cuda_q_weight,
89
+ b->cuda_gptq_qzeros,
90
+ b->cuda_gptq_scales,
91
+ c,
92
+ size_m,
93
+ size_n,
94
+ size_k,
95
+ b->groups,
96
+ b->groupsize,
97
+ b->cuda_q_perm,
98
+ b->rows_4,
99
+ clear
100
+ );
101
+ }
102
+ }
103
+
104
+ void gemm_half_q_half_cuda
105
+ (
106
+ cublasHandle_t cublas_handle,
107
+ const half* a,
108
+ QMatrix* b,
109
+ half* c,
110
+ int size_m,
111
+ int size_n,
112
+ int size_k,
113
+ bool clear,
114
+ half* temp_dq,
115
+ bool force_cuda
116
+ )
117
+ {
118
+ if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
119
+ {
120
+ //printf("cublas\n");
121
+
122
+ // Reconstruct FP16 matrix, then cuBLAS
123
+
124
+ if (!temp_dq) temp_dq = b->temp_dq;
125
+ b->reconstruct(temp_dq);
126
+
127
+ //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
128
+
129
+ const half alpha = __float2half(1.0f);
130
+ const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
131
+ cublasHgemm(cublas_handle,
132
+ CUBLAS_OP_N,
133
+ CUBLAS_OP_N,
134
+ size_n, size_m, size_k,
135
+ &alpha, temp_dq, size_n,
136
+ a, size_k,
137
+ &beta, c, size_n);
138
+
139
+ //const float alpha = 1.0f;
140
+ //const float beta = clear ? 0.0f : 1.0f;
141
+ //cublasSgemmEx(cublas_handle,
142
+ // CUBLAS_OP_N,
143
+ // CUBLAS_OP_N,
144
+ // size_n, size_m, size_k,
145
+ // &alpha, temp_dq, CUDA_R_16F, size_n,
146
+ // a, CUDA_R_16F, size_k,
147
+ // &beta, c, CUDA_R_16F, size_n);
148
+
149
+ //const float alpha = 1.0f;
150
+ //const float beta = clear ? 0.0f : 1.0f;
151
+ //cublasGemmEx(cublas_handle,
152
+ // CUBLAS_OP_N, CUBLAS_OP_N,
153
+ // size_n, size_m, size_k,
154
+ // &alpha, temp_dq, CUDA_R_16F, size_n,
155
+ // a, CUDA_R_16F, size_k,
156
+ // &beta, c, CUDA_R_16F, size_n,
157
+ // CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
158
+ }
159
+ else
160
+ {
161
+ //printf("cuda\n");
162
+
163
+ // Quantized matmul
164
+
165
+ //if (clear) clear_tensor_cuda(c, size_m, size_n);
166
+
167
+ int max_chunks = size_m / BLOCK_M_SIZE_MAX;
168
+ int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
169
+ int last_chunk_size = size_m - last_chunk;
170
+
171
+ if (max_chunks)
172
+ {
173
+ gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
174
+ }
175
+
176
+ if (last_chunk_size)
177
+ {
178
+ gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
179
+ }
180
+ }
181
+ }
182
+
183
+ __global__ void clear_kernel
184
+ (
185
+ half* __restrict__ c,
186
+ const int size_m,
187
+ const int size_n
188
+ )
189
+ {
190
+ int m = blockIdx.y;
191
+ int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
192
+ if (n >= size_n) return;
193
+ int4* c_ptr = (int4*)(c + m * size_n + n);
194
+ *c_ptr = {};
195
+ }
196
+
197
+ void clear_tensor_cuda
198
+ (
199
+ half* c,
200
+ int size_m,
201
+ int size_n
202
+ )
203
+ {
204
+ return;
205
+ dim3 blockDim, gridDim;
206
+ blockDim.x = CLEAR_N_SIZE;
207
+ blockDim.y = 1;
208
+ gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
209
+ gridDim.y = size_m;
210
+ clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
211
+ }
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cuh ADDED
@@ -0,0 +1,33 @@
1
+ #ifndef _q_gemm_cuh
2
+ #define _q_gemm_cuh
3
+
4
+ #include <cuda_runtime.h>
5
+ #include <cuda_fp16.h>
6
+ #include <cstdint>
7
+ #include <cstdio>
8
+ #include <ATen/cuda/CUDAContext.h>
9
+
10
+ #include "q_matrix.cuh"
11
+
12
+ void gemm_half_q_half_cuda
13
+ (
14
+ cublasHandle_t cublas_handle,
15
+ const half* a,
16
+ QMatrix* b,
17
+ half* c,
18
+ int size_m,
19
+ int size_n,
20
+ int size_k,
21
+ bool clear = false,
22
+ half* reconstruct = NULL,
23
+ bool force_cuda = false
24
+ );
25
+
26
+ void clear_tensor_cuda
27
+ (
28
+ half* c,
29
+ int size_m,
30
+ int size_n
31
+ );
32
+
33
+ #endif
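A hypothetical call site for the declaration above, assuming d_input and d_output are valid device buffers, qmatrix is an already-constructed QMatrix* with public height/width members, and the cuBLAS handle comes from ATen as in a Torch extension:

    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    gemm_half_q_half_cuda(handle, d_input, qmatrix, d_output,
                          batch_tokens,      // size_m: rows of the activation matrix
                          qmatrix->width,    // size_n
                          qmatrix->height,   // size_k
                          /*clear=*/true);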
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel.cuh ADDED
@@ -0,0 +1,487 @@
1
+ #include "compat.cuh"
2
+
3
+ #include <cuda_runtime.h>
4
+ #include <cuda_fp16.h>
5
+
6
+ __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
7
+ {
8
+ half2 result = {};
9
+ const half2* a2_ptr = (const half2*)a_ptr;
10
+ #pragma unroll
11
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
12
+ return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
13
+ }
14
+
15
+ __forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
16
+ {
17
+ half2 result = {};
18
+ const half2* a2_ptr = (const half2*)a_ptr;
19
+ #pragma unroll
20
+ for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
21
+ return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
22
+ }
23
+
24
+ __forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
25
+ {
26
+ half2 result = {};
27
+ const half2* a2_ptr = (const half2*)a_ptr;
28
+ #pragma unroll
29
+ for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
30
+ return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
31
+ }
32
+
33
+ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
34
+ {
35
+ half2 result = {};
36
+ const half2* a2_ptr = (const half2*)a_ptr;
37
+ #pragma unroll
38
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
39
+ float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
40
+ return fma(result_f, qs_f, g_result);
41
+ }
42
+
43
+ __forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
44
+ {
45
+ half2 result = {};
46
+ const half2* a2_ptr = (const half2*)a_ptr;
47
+ #pragma unroll
48
+ for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
49
+ float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
50
+ return fma(result_f, qs_f, g_result);
51
+ }
52
+
53
+ __forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
54
+ {
55
+ half2 result = {};
56
+ const half2* a2_ptr = (const half2*)a_ptr;
57
+ #pragma unroll
58
+ for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
59
+ float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
60
+ return fma(result_f, qs_f, g_result);
61
+ }
62
+
63
+
64
+
65
+ typedef void (*fp_gemm_half_q_half_kernel)
66
+ (
67
+ const half*,
68
+ const uint32_t*,
69
+ const uint32_t*,
70
+ const half*,
71
+ half*,
72
+ const int,
73
+ const int,
74
+ const int,
75
+ const int,
76
+ const int,
77
+ const uint16_t*,
78
+ const int,
79
+ const int,
80
+ const int,
81
+ const int,
82
+ const int,
83
+ const int,
84
+ const bool
85
+ );
86
+
87
+ template <bool first_block, int m_count>
88
+ __global__ void gemm_half_q_half_kernel
89
+ (
90
+ const half* __restrict__ a,
91
+ const uint32_t* __restrict__ b_q_weight,
92
+ const uint32_t* __restrict__ b_q_scale,
93
+ const half* __restrict__ b_q_scale_max,
94
+ half* __restrict__ c,
95
+ const int size_m,
96
+ const int size_n,
97
+ const int size_k,
98
+ const int groups,
99
+ const int groupsize,
100
+ const uint16_t* __restrict__ b_q_perm,
101
+ const int rows_8,
102
+ const int rows_6,
103
+ const int rows_5,
104
+ const int rows_4,
105
+ const int rows_3,
106
+ const int rows_2,
107
+ const bool clear
108
+ )
109
+ {
110
+ MatrixView_half a_(a, size_m, size_k);
111
+ MatrixView_half_rw c_(c, size_m, size_n);
112
+ MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
113
+
114
+ int t = threadIdx.x;
115
+
116
+ // Block
117
+
118
+ int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
119
+ int offset_m = blockIdx.y * m_count;
120
+ int offset_k = blockIdx.z * BLOCK_KN_SIZE;
121
+
122
+ int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
123
+ int end_m = min(offset_m + m_count, size_m);
124
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
125
+ int n = offset_n + t * 4;
126
+
127
+ // Preload block_a
128
+
129
+ __shared__ half block_a[m_count][BLOCK_KN_SIZE];
130
+
131
+ if (offset_k + t < end_k)
132
+ {
133
+ for (int m = 0; m < m_count; ++m)
134
+ {
135
+ const half* a_ptr = a_.item_ptr(offset_m + m, 0);
136
+ half* block_a_ptr = block_a[m];
137
+ half a0 = a_ptr[b_q_perm[offset_k + t]];
138
+ block_a_ptr[t] = a0;
139
+ }
140
+ }
141
+
142
+ // Clear
143
+
144
+ if (n >= size_n) return;
145
+
146
+ if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
147
+ {
148
+ for (int m = 0; m < m_count; m++)
149
+ *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
150
+ }
151
+
152
+ __syncthreads();
153
+
154
+ // Find initial group
155
+
156
+ int group = offset_k / groupsize;
157
+
158
+ // Preload scales
159
+
160
+ float scales[MAX_GROUPS_IN_BLOCK][4];
161
+
162
+ int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
163
+ for (int g = 0; g < groups_in_block; g++)
164
+ {
165
+ int qscales[4];
166
+ b_q_scale_.item4(qscales, group + g, n);
167
+ qscales[0]++;
168
+ qscales[1]++;
169
+ qscales[2]++;
170
+ qscales[3]++;
171
+ float maxscale = __half2float(b_q_scale_max[group + g]);
172
+ scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
173
+ scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
174
+ scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
175
+ scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
176
+ }
177
+
178
+ // a, b offset
179
+
180
+ int pre_rows_8 = min(rows_8, offset_k);
181
+ int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
182
+ int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
183
+ int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
184
+ int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
185
+ int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
186
+ int qk = 0;
187
+ qk += pre_rows_8 / 32 * 8;
188
+ qk += pre_rows_6 / 32 * 6;
189
+ qk += pre_rows_5 / 32 * 5;
190
+ qk += pre_rows_4 / 32 * 4;
191
+ qk += pre_rows_3 / 32 * 3;
192
+ qk += pre_rows_2 / 32 * 2;
193
+
194
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
195
+ const half* a_ptr = &block_a[0][0];
196
+ int a_stride = BLOCK_KN_SIZE;
197
+
198
+ // Initial group
199
+
200
+ int scales_idx = 0;
201
+ float qs_f0 = scales[scales_idx][0];
202
+ float qs_f1 = scales[scales_idx][1];
203
+ float qs_f2 = scales[scales_idx][2];
204
+ float qs_f3 = scales[scales_idx][3];
205
+ int nextgroup = offset_k + groupsize;
206
+
207
+ // Column result
208
+
209
+ float block_c[m_count][4] = {};
210
+
211
+ // Dequantize groups
212
+
213
+ int k = offset_k;
214
+
215
+ while (k < rows_8 && k < end_k)
216
+ {
217
+ if (k == nextgroup)
218
+ {
219
+ group++;
220
+ scales_idx++;
221
+ qs_f0 = scales[scales_idx][0];
222
+ qs_f1 = scales[scales_idx][1];
223
+ qs_f2 = scales[scales_idx][2];
224
+ qs_f3 = scales[scales_idx][3];
225
+ nextgroup += groupsize;
226
+ }
227
+
228
+ #pragma unroll
229
+ for (int j = 0; j < 4; j++)
230
+ {
231
+ int4 load_int4[2];
232
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
233
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
234
+
235
+ half2 dq[4][4];
236
+ dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
237
+ dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
238
+ dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
239
+ dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
240
+
241
+ for (int m = 0; m < m_count; m++)
242
+ {
243
+ block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
244
+ block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
245
+ block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
246
+ block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
247
+ }
248
+ a_ptr += 8;
249
+ }
250
+ k += 32;
251
+ }
252
+
253
+ while (k < rows_6 && k < end_k)
254
+ {
255
+ if (k == nextgroup)
256
+ {
257
+ group++;
258
+ scales_idx++;
259
+ qs_f0 = scales[scales_idx][0];
260
+ qs_f1 = scales[scales_idx][1];
261
+ qs_f2 = scales[scales_idx][2];
262
+ qs_f3 = scales[scales_idx][3];
263
+ nextgroup += groupsize;
264
+ }
265
+
266
+ #pragma unroll
267
+ for (int j = 0; j < 2; j++)
268
+ {
269
+ int4 load_int4[3];
270
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
271
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
272
+ load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
273
+
274
+ half2 dq[4][8];
275
+ dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
276
+ dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
277
+ dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
278
+ dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
279
+
280
+ for (int m = 0; m < m_count; m++)
281
+ {
282
+ block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
283
+ block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
284
+ block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
285
+ block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
286
+ }
287
+ a_ptr += 16;
288
+ }
289
+ k += 32;
290
+ }
291
+
292
+ while (k < rows_5 && k < end_k)
293
+ {
294
+ if (k == nextgroup)
295
+ {
296
+ group++;
297
+ scales_idx++;
298
+ qs_f0 = scales[scales_idx][0];
299
+ qs_f1 = scales[scales_idx][1];
300
+ qs_f2 = scales[scales_idx][2];
301
+ qs_f3 = scales[scales_idx][3];
302
+ nextgroup += groupsize;
303
+ }
304
+
305
+ #pragma unroll
306
+ for (int j = 0; j < 1; j++)
307
+ {
308
+ int4 load_int4[5];
309
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
310
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
311
+ load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
312
+ load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
313
+ load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
314
+
315
+ half2 dq[4][16];
316
+ dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
317
+ dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
318
+ dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
319
+ dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
320
+
321
+ for (int m = 0; m < m_count; m++)
322
+ {
323
+ block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
324
+ block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
325
+ block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
326
+ block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
327
+ }
328
+ a_ptr += 32;
329
+ }
330
+
331
+ k += 32;
332
+ }
333
+
334
+ while (k < rows_4 && k < end_k)
335
+ {
336
+ if (k == nextgroup)
337
+ {
338
+ group++;
339
+ scales_idx++;
340
+ qs_f0 = scales[scales_idx][0];
341
+ qs_f1 = scales[scales_idx][1];
342
+ qs_f2 = scales[scales_idx][2];
343
+ qs_f3 = scales[scales_idx][3];
344
+ nextgroup += groupsize;
345
+ }
346
+
347
+ #pragma unroll
348
+ for (int j = 0; j < 4; j++)
349
+ {
350
+ int4 load_int4[1];
351
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
352
+
353
+ half2 dq[4][4];
354
+ dequant_4bit_8(load_int4[0].x, dq[0], size_n);
355
+ dequant_4bit_8(load_int4[0].y, dq[1], size_n);
356
+ dequant_4bit_8(load_int4[0].z, dq[2], size_n);
357
+ dequant_4bit_8(load_int4[0].w, dq[3], size_n);
358
+
359
+ for (int m = 0; m < m_count; m++)
360
+ {
361
+ block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
362
+ block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
363
+ block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
364
+ block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
365
+ }
366
+ a_ptr += 8;
367
+ }
368
+ k += 32;
369
+ }
370
+
371
+ while (k < rows_3 && k < end_k)
372
+ {
373
+ if (k == nextgroup)
374
+ {
375
+ group++;
376
+ scales_idx++;
377
+ qs_f0 = scales[scales_idx][0];
378
+ qs_f1 = scales[scales_idx][1];
379
+ qs_f2 = scales[scales_idx][2];
380
+ qs_f3 = scales[scales_idx][3];
381
+ nextgroup += groupsize;
382
+ }
383
+
384
+ #pragma unroll
385
+ for (int j = 0; j < 1; j++)
386
+ {
387
+ int4 load_int4[3];
388
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
389
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
390
+ load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
391
+
392
+ half2 dq[4][16];
393
+ dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
394
+ dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
395
+ dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
396
+ dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
397
+
398
+ for (int m = 0; m < m_count; m++)
399
+ {
400
+ block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
401
+ block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
402
+ block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
403
+ block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
404
+ }
405
+ a_ptr += 32;
406
+ }
407
+ k += 32;
408
+ }
409
+
410
+ while (k < rows_2 && k < end_k)
411
+ {
412
+ if (k == nextgroup)
413
+ {
414
+ group++;
415
+ scales_idx++;
416
+ qs_f0 = scales[scales_idx][0];
417
+ qs_f1 = scales[scales_idx][1];
418
+ qs_f2 = scales[scales_idx][2];
419
+ qs_f3 = scales[scales_idx][3];
420
+ nextgroup += groupsize;
421
+ }
422
+
423
+ #pragma unroll
424
+ for (int j = 0; j < 2; j++)
425
+ {
426
+ int4 load_int4[1];
427
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
428
+
429
+ half2 dq[4][8];
430
+ dequant_2bit_16(load_int4[0].x, dq[0], size_n);
431
+ dequant_2bit_16(load_int4[0].y, dq[1], size_n);
432
+ dequant_2bit_16(load_int4[0].z, dq[2], size_n);
433
+ dequant_2bit_16(load_int4[0].w, dq[3], size_n);
434
+
435
+ for (int m = 0; m < m_count; m++)
436
+ {
437
+ block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
438
+ block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
439
+ block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
440
+ block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
441
+ }
442
+
443
+ a_ptr += 16;
444
+ }
445
+ k += 32;
446
+ }
447
+
448
+ // Accumulate column sums in c
449
+
450
+ for (int m = 0; m < m_count; m++)
451
+ {
452
+ half2* out = (half2*)c_.item_ptr(offset_m + m, n);
453
+ half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
454
+ half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
455
+ atomicAdd(out , result01);
456
+ atomicAdd(out + 1, result23);
457
+ }
458
+ }
459
+
460
+ fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
461
+ {
462
+ #if BLOCK_M_SIZE_MAX >= 1
463
+ if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
464
+ #endif
465
+ #if BLOCK_M_SIZE_MAX >= 2
466
+ if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
467
+ #endif
468
+ #if BLOCK_M_SIZE_MAX >= 3
469
+ if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
470
+ #endif
471
+ #if BLOCK_M_SIZE_MAX >= 4
472
+ if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
473
+ #endif
474
+ #if BLOCK_M_SIZE_MAX >= 5
475
+ if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
476
+ #endif
477
+ #if BLOCK_M_SIZE_MAX >= 6
478
+ if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
479
+ #endif
480
+ #if BLOCK_M_SIZE_MAX >= 7
481
+ if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
482
+ #endif
483
+ #if BLOCK_M_SIZE_MAX >= 8
484
+ if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
485
+ #endif
486
+ return NULL;
487
+ }
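In the scale-preloading loop of the kernel above, each stored 4-bit scale q is decoded as (q + 1) squared times the group's maxscale; a worked example with hypothetical values:

    // stored qscale = 3, b_q_scale_max[group] = 0.01f
    // decoded scale = (3 + 1) * (3 + 1) * 0.01f = 0.16f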
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel_gptq.cuh ADDED
@@ -0,0 +1,219 @@
1
+ #include "compat.cuh"
2
+
3
+ __forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
4
+ {
5
+ half2 result = {};
6
+ const half2* a2_ptr = (const half2*)a_ptr;
7
+ #pragma unroll
8
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
9
+ return __hadd2(result, g_result);
10
+ }
11
+
12
+ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
13
+ {
14
+ half2 result = {};
15
+ const half2* a2_ptr = (const half2*)a_ptr;
16
+ #pragma unroll
17
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
18
+ return __half2float(__low2half(result)) + __half2float(__high2half(result));
19
+ }
20
+
21
+ typedef void (*fp_gemm_half_q_half_gptq_kernel)
22
+ (
23
+ const half*,
24
+ const uint32_t*,
25
+ const uint32_t*,
26
+ const half*,
27
+ half*,
28
+ const int,
29
+ const int,
30
+ const int,
31
+ const int,
32
+ const int,
33
+ const uint16_t*,
34
+ const int,
35
+ const bool
36
+ );
37
+
38
+ template <bool first_block, int m_count>
39
+ __global__ void gemm_half_q_half_gptq_kernel
40
+ (
41
+ const half* __restrict__ a,
42
+ const uint32_t* __restrict__ b_q_weight,
43
+ const uint32_t* __restrict__ b_gptq_qzeros,
44
+ const half* __restrict__ b_gptq_scales,
45
+ half* __restrict__ c,
46
+ const int size_m,
47
+ const int size_n,
48
+ const int size_k,
49
+ const int groups,
50
+ const int groupsize,
51
+ const uint16_t* __restrict__ b_q_perm,
52
+ const int rows_4,
53
+ const bool clear
54
+ )
55
+ {
56
+ MatrixView_half a_(a, size_m, size_k);
57
+ MatrixView_half_rw c_(c, size_m, size_n);
58
+ MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
59
+ MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
60
+
61
+ int t = threadIdx.x;
62
+
63
+ // Block
64
+
65
+ int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
66
+ int offset_m = blockIdx.y * m_count;
67
+ int offset_k = blockIdx.z * BLOCK_KN_SIZE;
68
+
69
+ int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
70
+ int end_m = min(offset_m + m_count, size_m);
71
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
72
+
73
+ int n = offset_n + t * 4;
74
+
75
+ // Preload block_a
76
+
77
+ __shared__ half block_a[m_count][BLOCK_KN_SIZE];
78
+
79
+ if (offset_k + t < end_k)
80
+ {
81
+ for (int m = 0; m < m_count; ++m)
82
+ {
83
+ const half* a_ptr = a_.item_ptr(offset_m + m, 0);
84
+ half* block_a_ptr = block_a[m];
85
+
86
+ half a0;
87
+ if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
88
+ else a0 = a_ptr[offset_k + t];
89
+ block_a_ptr[t] = a0;
90
+ }
91
+ }
92
+
93
+ // Zero output
94
+
95
+ if (n >= size_n) return;
96
+
97
+ if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
98
+ {
99
+ for (int m = 0; m < m_count; m++)
100
+ *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
101
+ }
102
+
103
+ __syncthreads();
104
+
105
+ // Find initial group
106
+
107
+ int group = offset_k / groupsize;
108
+ int nextgroup = offset_k + groupsize;
109
+
110
+ // a, b offset
111
+
112
+ int qk = offset_k / (32 / 4);
113
+
114
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
115
+ const half* a_ptr = &block_a[0][0];
116
+ int a_stride = BLOCK_KN_SIZE;
117
+
118
+ // Initial group
119
+
120
+ int zeros[4];
121
+ float scales[4];
122
+ half2 z1z16[4][2];
123
+ half2 y1y16[4][2];
124
+ b_gptq_qzeros_.item4(zeros, group, n);
125
+ b_gptq_scales_.item4_f(scales, group, n);
126
+ dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
127
+ dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
128
+ dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
129
+ dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
130
+
131
+ // __syncthreads();
132
+
133
+ // Column result
134
+
135
+ float block_c[m_count][4] = {};
136
+
137
+ // Dequantize and multiply
138
+
139
+ int k = offset_k;
140
+ while (k < end_k)
141
+ {
142
+ if (k == nextgroup)
143
+ {
144
+ group++;
145
+ nextgroup += groupsize;
146
+ b_gptq_qzeros_.item4(zeros, group, n);
147
+ b_gptq_scales_.item4_f(scales, group, n);
148
+ dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
149
+ dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
150
+ dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
151
+ dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
152
+ }
153
+
154
+ #pragma unroll
155
+ for (int j = 0; j < 4; j++)
156
+ {
157
+ const int4* b_ptr4 = (int4*) b_ptr;
158
+ int4 load_int4 = *b_ptr4;
159
+
160
+ half2 dq[4][4];
161
+ dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
162
+ dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
163
+ dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
164
+ dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
165
+
166
+ #pragma unroll
167
+ for (int m = 0; m < m_count; m++)
168
+ {
169
+ block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
170
+ block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
171
+ block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
172
+ block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
173
+ }
174
+
175
+ b_ptr += size_n;
176
+ a_ptr += 8;
177
+ }
178
+
179
+ k += 32;
180
+ }
181
+
182
+ for (int m = 0; m < m_count; m++)
183
+ {
184
+ half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
185
+ half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
186
+ half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
187
+ atomicAdd(out , result01);
188
+ atomicAdd(out + 1, result23);
189
+ }
190
+ }
191
+
192
+ fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
193
+ {
194
+ #if BLOCK_M_SIZE_MAX >= 1
195
+ if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
196
+ #endif
197
+ #if BLOCK_M_SIZE_MAX >= 2
198
+ if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
199
+ #endif
200
+ #if BLOCK_M_SIZE_MAX >= 3
201
+ if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
202
+ #endif
203
+ #if BLOCK_M_SIZE_MAX >= 4
204
+ if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
205
+ #endif
206
+ #if BLOCK_M_SIZE_MAX >= 5
207
+ if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
208
+ #endif
209
+ #if BLOCK_M_SIZE_MAX >= 6
210
+ if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
211
+ #endif
212
+ #if BLOCK_M_SIZE_MAX >= 7
213
+ if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
214
+ #endif
215
+ #if BLOCK_M_SIZE_MAX >= 8
216
+ if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
217
+ #endif
218
+ return NULL;
219
+ }
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cu ADDED
@@ -0,0 +1,623 @@
1
+ #include "q_matrix.cuh"
2
+ #include "matrix_view.cuh"
3
+ #include "util.cuh"
4
+
5
+ #include "quant/qdq_2.cuh"
6
+ #include "quant/qdq_3.cuh"
7
+ #include "quant/qdq_4.cuh"
8
+ #include "quant/qdq_5.cuh"
9
+ #include "quant/qdq_6.cuh"
10
+ #include "quant/qdq_8.cuh"
11
+
12
+ #define BLOCK_KN_SIZE 128
13
+
14
+ #define THREADS_X 32
15
+ #define THREADS_Y 32
16
+
17
+ // Shuffle quantized data on load
18
+
19
+ __global__ void shuffle_kernel
20
+ (
21
+ uint32_t* __restrict__ b_q_weight,
22
+ const int size_k,
23
+ const int size_n,
24
+ const int rows_8,
25
+ const int rows_6,
26
+ const int rows_5,
27
+ const int rows_4,
28
+ const int rows_3,
29
+ const int rows_2
30
+ )
31
+ {
32
+ int n = blockIdx.x * THREADS_X + threadIdx.x;
33
+ if (n >= size_n) return;
34
+ int k = 0;
35
+ uint32_t* b_ptr = b_q_weight + n;
36
+ while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
37
+ while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
38
+ while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
39
+ while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
40
+ while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
41
+ while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
42
+ }
43
+
44
+
45
+ // QMatrix constructor
46
+
47
+ QMatrix::QMatrix
48
+ (
49
+ const int _device,
50
+ const int _height,
51
+ const int _width,
52
+ const int _groups,
53
+
54
+ uint32_t* _q_weight,
55
+ uint16_t* _q_perm,
56
+ uint16_t* _q_invperm,
57
+ uint32_t* _q_scale,
58
+ half* _q_scale_max,
59
+ uint16_t* _q_groups,
60
+
61
+ uint32_t* _gptq_qzeros,
62
+ half* _gptq_scales,
63
+ uint32_t* _gptq_g_idx,
64
+
65
+ half* _temp_dq
66
+ ) :
67
+ device(_device),
68
+ height(_height),
69
+ width(_width),
70
+ groups(_groups),
71
+ temp_dq(_temp_dq)
72
+ {
73
+ cudaSetDevice(device);
74
+
75
+ failed = false;
76
+
77
+ cuda_q_weight = _q_weight;
78
+ cuda_q_perm = _q_perm;
79
+ cuda_q_invperm = _q_invperm;
80
+ cuda_q_scale = _q_scale;
81
+ cuda_q_scale_max = _q_scale_max;
82
+ cuda_q_groups = _q_groups;
83
+ cuda_gptq_qzeros = _gptq_qzeros;
84
+ cuda_gptq_scales = _gptq_scales;
85
+
86
+ is_gptq = (_gptq_qzeros != NULL);
87
+
88
+ groupsize = 1;
89
+ while (groupsize * groups < height) groupsize *= 2;
90
+
91
+ // Create group map
92
+
93
+ rows_8 = 0;
94
+ rows_6 = 0;
95
+ rows_5 = 0;
96
+ rows_4 = 0;
97
+ rows_3 = 0;
98
+ rows_2 = 0;
99
+
100
+ if (!is_gptq)
101
+ {
102
+ uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
103
+ cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
104
+
105
+ for (int i = 0; i < groups; i++)
106
+ {
107
+ int bits = cpu_q_groups[i * 2];
108
+ if (bits == 8) rows_8 += groupsize;
109
+ if (bits == 6) rows_6 += groupsize;
110
+ if (bits == 5) rows_5 += groupsize;
111
+ if (bits == 4) rows_4 += groupsize;
112
+ if (bits == 3) rows_3 += groupsize;
113
+ if (bits == 2) rows_2 += groupsize;
114
+ }
115
+
116
+ free(cpu_q_groups);
117
+
118
+ rows_6 += rows_8;
119
+ rows_5 += rows_6;
120
+ rows_4 += rows_5;
121
+ rows_3 += rows_4;
122
+ rows_2 += rows_3;
123
+ }
124
+ else
125
+ {
126
+ rows_4 = height;
127
+ rows_3 = height;
128
+ rows_2 = height;
129
+
130
+ if (_gptq_g_idx)
131
+ {
132
+ if (!make_sequential(_gptq_g_idx))
133
+ {
134
+ failed = true;
135
+ //printf("FAIL\n");
136
+ return;
137
+ }
138
+ }
139
+ }
140
+
141
+ // Shuffle quantized data
142
+
143
+ dim3 blockDim, gridDim;
144
+ blockDim.x = THREADS_X;
145
+ blockDim.y = 1;
146
+ gridDim.x = DIVIDE(width, THREADS_X);
147
+ gridDim.y = 1;
148
+
149
+ shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
150
+ }
151
+
152
+ QMatrix::~QMatrix()
153
+ {
154
+ }
155
+
156
+ // Reconstruct b[k,n] (GPTQ)
157
+
158
+ __global__ void reconstruct_gptq_kernel
159
+ (
160
+ const uint32_t* __restrict__ b_q_weight,
161
+ const uint16_t* __restrict__ b_q_perm,
162
+ const uint32_t* __restrict__ b_gptq_qzeros,
163
+ const half* __restrict__ b_gptq_scales,
164
+ //const uint16_t* __restrict__ b_q_groups,
165
+ const int size_k,
166
+ const int size_n,
167
+ const int groupsize,
168
+ const int groups,
169
+ half* __restrict__ b,
170
+ const int rows_4
171
+ )
172
+ {
173
+ MatrixView_half_rw b_(b, size_k, size_n);
174
+ MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
175
+ MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
176
+
177
+ int offset_k = BLOCK_KN_SIZE * blockIdx.y;
178
+ int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
179
+
180
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
181
+
182
+ // Preload remapping table
183
+
184
+ __shared__ uint16_t perm[BLOCK_KN_SIZE];
185
+ int t = threadIdx.x;
186
+
187
+ if (b_q_perm)
188
+ {
189
+ if (offset_k + t < size_k)
190
+ perm[t] = b_q_perm[offset_k + t];
191
+ }
192
+
193
+ // Column
194
+
195
+ int n = offset_n + t * 4;
196
+ if (n >= size_n) return;
197
+
198
+ // Find initial group
199
+
200
+ int group = offset_k / groupsize;
201
+ int nextgroup = offset_k + groupsize;
202
+
203
+ // b offset
204
+
205
+ int qk = offset_k / (32 / 4);
206
+
207
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
208
+
209
+ // Initial zeros/scale
210
+
211
+ int zeros[4];
212
+ half2 scales[4];
213
+ half2 z1z16[4][2];
214
+ half2 y1y16[4][2];
215
+ b_gptq_qzeros_.item4(zeros, group, n);
216
+ b_gptq_scales_.item4_h2(scales, group, n);
217
+ dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
218
+ dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
219
+ dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
220
+ dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
221
+
222
+ __syncthreads();
223
+
224
+ int k = offset_k;
225
+ int lk = 0;
226
+
227
+ while (k < end_k)
228
+ {
229
+ if (k == nextgroup)
230
+ {
231
+ group++;
232
+ nextgroup += groupsize;
233
+ b_gptq_qzeros_.item4(zeros, group, n);
234
+ b_gptq_scales_.item4_h2(scales, group, n);
235
+ dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
236
+ dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
237
+ dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
238
+ dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
239
+ }
240
+
241
+ for (int p = 0; p < 4; p++)
242
+ {
243
+ half2 dq[4][4];
244
+ const int4* b_ptr4 = (int4*) b_ptr;
245
+ int4 load_int4 = *b_ptr4;
246
+
247
+ dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
248
+ dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
249
+ dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
250
+ dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
251
+
252
+ b_ptr += size_n;
253
+ //half* dqh = (half*)dq;
254
+ if (b_q_perm)
255
+ {
256
+ for (int j = 0; j < 4; j++)
257
+ {
258
+ for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
259
+ b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
260
+ b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
261
+ }
262
+ }
263
+ else
264
+ {
265
+ for (int j = 0; j < 4; j++)
266
+ {
267
+ for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
268
+ b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
269
+ b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
270
+ }
271
+ }
272
+ }
273
+ k += 32;
274
+ }
275
+ }
276
+
277
+
278
+ // Reconstruct b[k,n]
279
+
280
+ __global__ void reconstruct_kernel
281
+ (
282
+ const uint32_t* __restrict__ b_q_weight,
283
+ const uint16_t* __restrict__ b_q_perm,
284
+ const uint32_t* __restrict__ b_q_scale,
285
+ const half* __restrict__ b_q_scale_max,
286
+ //const uint16_t* __restrict__ b_q_groups,
287
+ const int size_k,
288
+ const int size_n,
289
+ const int groupsize,
290
+ const int groups,
291
+ half* __restrict__ b,
292
+ const int rows_8,
293
+ const int rows_6,
294
+ const int rows_5,
295
+ const int rows_4,
296
+ const int rows_3,
297
+ const int rows_2
298
+ )
299
+ {
300
+ MatrixView_half_rw b_(b, size_k, size_n);
301
+ MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
302
+
303
+ int offset_k = BLOCK_KN_SIZE * blockIdx.y;
304
+ int offset_n = BLOCK_KN_SIZE * blockIdx.x;
305
+
306
+ // Preload remapping table
307
+
308
+ int t = threadIdx.x;
309
+ __shared__ uint16_t perm[BLOCK_KN_SIZE];
310
+ if (offset_k + t < size_k)
311
+ perm[t] = b_q_perm[offset_k + t];
312
+
313
+ // Column
314
+
315
+ int n = offset_n + t;
316
+ if (n >= size_n) return;
317
+
318
+ // Find initial group
319
+
320
+ int group = offset_k / groupsize;
321
+
322
+ int pre_rows_8 = min(rows_8, offset_k);
323
+ int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
324
+ int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
325
+ int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
326
+ int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
327
+ int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
328
+ int qk = 0;
329
+ qk += pre_rows_8 / 32 * 8;
330
+ qk += pre_rows_6 / 32 * 6;
331
+ qk += pre_rows_5 / 32 * 5;
332
+ qk += pre_rows_4 / 32 * 4;
333
+ qk += pre_rows_3 / 32 * 3;
334
+ qk += pre_rows_2 / 32 * 2;
335
+
336
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
337
+
338
+ half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
339
+ half2 qs_h2 = __halves2half2(qs_h, qs_h);
340
+ int nextgroup = offset_k + groupsize;
341
+
342
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
343
+ int k = offset_k;
344
+ int lk = 0;
345
+
346
+ __syncthreads();
347
+
348
+ while (k < rows_8 && k < end_k)
349
+ {
350
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
351
+ for (int p = 0; p < 4; p++)
352
+ {
353
+ half2 dq[4];
354
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
355
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
356
+ dequant_8bit_8(q_0, q_1, dq, size_n);
357
+ for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
358
+ half* dqh = (half*) dq;
359
+ for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
360
+ }
361
+ k += 32;
362
+ }
363
+
364
+ while (k < rows_6 && k < end_k)
365
+ {
366
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
367
+ for (int p = 0; p < 2; p++)
368
+ {
369
+ half2 dq[8];
370
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
371
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
372
+ uint32_t q_2 = *b_ptr; b_ptr += size_n;
373
+ dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
374
+ for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
375
+ half* dqh = (half*) dq;
376
+ for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
377
+ }
378
+ k += 32;
379
+ }
380
+
381
+ while (k < rows_5 && k < end_k)
382
+ {
383
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
384
+ for (int p = 0; p < 1; p++)
385
+ {
386
+ half2 dq[16];
387
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
388
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
389
+ uint32_t q_2 = *b_ptr; b_ptr += size_n;
390
+ uint32_t q_3 = *b_ptr; b_ptr += size_n;
391
+ uint32_t q_4 = *b_ptr; b_ptr += size_n;
392
+ dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
393
+ for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
394
+ half* dqh = (half*) dq;
395
+ for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
396
+ }
397
+ k += 32;
398
+ }
399
+
400
+ while (k < rows_4 && k < end_k)
401
+ {
402
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
403
+ for (int p = 0; p < 4; p++)
404
+ {
405
+ half2 dq[4];
406
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
407
+ dequant_4bit_8(q_0, dq, size_n);
408
+ for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
409
+ half* dqh = (half*) dq;
410
+ for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
411
+ }
412
+ k += 32;
413
+ }
414
+
415
+ while (k < rows_3 && k < end_k)
416
+ {
417
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
418
+ for (int p = 0; p < 1; p++)
419
+ {
420
+ half2 dq[16];
421
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
422
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
423
+ uint32_t q_2 = *b_ptr; b_ptr += size_n;
424
+ dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
425
+ for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
426
+ half* dqh = (half*) dq;
427
+ for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
428
+ }
429
+ k += 32;
430
+ }
431
+
432
+ while (k < rows_2 && k < end_k)
433
+ {
434
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
435
+ for (int p = 0; p < 2; p++)
436
+ {
437
+ half2 dq[8];
438
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
439
+ dequant_2bit_16(q_0, dq, size_n);
440
+ for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
441
+ half* dqh = (half*) dq;
442
+ for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
443
+ }
444
+ k += 32;
445
+ }
446
+ }
447
+
448
+ void QMatrix::reconstruct(half* out)
449
+ {
450
+ dim3 blockDim, gridDim;
451
+ blockDim.x = BLOCK_KN_SIZE;
452
+ blockDim.y = 1;
453
+ gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
454
+
455
+ if (!is_gptq)
456
+ {
457
+ gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
458
+ reconstruct_kernel<<<gridDim, blockDim>>>
459
+ (
460
+ cuda_q_weight,
461
+ cuda_q_perm,
462
+ cuda_q_scale,
463
+ cuda_q_scale_max,
464
+ //cuda_q_groups,
465
+ height,
466
+ width,
467
+ groupsize,
468
+ groups,
469
+ out,
470
+ rows_8,
471
+ rows_6,
472
+ rows_5,
473
+ rows_4,
474
+ rows_3,
475
+ rows_2
476
+ );
477
+ }
478
+ else
479
+ {
480
+ gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
481
+ reconstruct_gptq_kernel<<<gridDim, blockDim>>>
482
+ (
483
+ cuda_q_weight,
484
+ cuda_q_perm,
485
+ cuda_gptq_qzeros,
486
+ cuda_gptq_scales,
487
+ //const uint16_t* __restrict__ b_q_groups,
488
+ height,
489
+ width,
490
+ groupsize,
491
+ groups,
492
+ out,
493
+ rows_4
494
+ );
495
+ }
496
+ }
497
+
498
+ __global__ void make_sequential_kernel
499
+ (
500
+ const uint32_t* __restrict__ w,
501
+ uint32_t* __restrict__ w_new,
502
+ const uint16_t* __restrict__ q_perm,
503
+ const int w_height,
504
+ const int w_width
505
+ )
506
+ {
507
+ const uint64_t* w2 = (uint64_t*) w;
508
+ uint64_t* w_new2 = (uint64_t*) w_new;
509
+ int w2_stride = w_width >> 1;
510
+
511
+ int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
512
+ if (w2_column >= w2_stride) return;
513
+
514
+ int w_new2_row = blockIdx.y;
515
+
516
+ int q_perm_idx = w_new2_row << 3;
517
+
518
+ uint64_t dst = 0;
519
+
520
+ #pragma unroll
521
+ for (int i = 0; i < 8; i++)
522
+ {
523
+ int source_row = q_perm[q_perm_idx++];
524
+
525
+ int w2_row = source_row >> 3;
526
+ int w2_subrow = source_row & 0x07;
527
+ int w2_row_shift = w2_subrow << 2;
528
+ int wnew2_row_shift = i << 2;
529
+
530
+ uint64_t src = w2[w2_row * w2_stride + w2_column];
531
+ src >>= w2_row_shift;
532
+ src &= 0x0000000f0000000f;
533
+ src <<= wnew2_row_shift;
534
+ dst |= src;
535
+ }
536
+
537
+ w_new2[w_new2_row * w2_stride + w2_column] = dst;
538
+ }
539
+
540
+ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
541
+ {
542
+ uint32_t* cuda_new_qweight = NULL;
543
+ cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
544
+ if (err != cudaSuccess) {
545
+ cudaError_t cuda_status = cudaGetLastError(); // Clear error
546
+ return false;
547
+ }
548
+
549
+ uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
550
+ uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
551
+ uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
552
+
553
+ // Group histogram
554
+
555
+ for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
556
+
557
+ // Group map
558
+
559
+ for (int i = 0, acc = 0; i < groups; i++)
560
+ {
561
+ short tmp = cpu_g_idx_map[i];
562
+ cpu_g_idx_map[i] = acc;
563
+ acc += tmp;
564
+ }
565
+
566
+ // X map (inverse)
567
+
568
+ for (int row = 0; row < height; row++)
569
+ {
570
+ uint32_t target_group = cpu_g_idx[row];
571
+ uint32_t target_row = cpu_g_idx_map[target_group];
572
+ cpu_g_idx_map[target_group]++;
573
+ cpu_x_map_inv[row] = target_row;
574
+ }
575
+
576
+ // X map
577
+
578
+ for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
579
+
580
+ // Reduce to uint16_t
581
+
582
+ uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
583
+ uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
584
+ for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
585
+ for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
586
+
587
+ // Move to CUDA
588
+
589
+ cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
590
+ cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
591
+
592
+ // Rearrange rows in w
593
+
594
+ dim3 blockDim, gridDim;
595
+ blockDim.x = THREADS_X;
596
+ blockDim.y = 1;
597
+ gridDim.x = DIVIDE(width, THREADS_X);
598
+ gridDim.y = height / 8;
599
+
600
+ make_sequential_kernel<<<gridDim, blockDim>>>
601
+ (
602
+ cuda_q_weight,
603
+ cuda_new_qweight,
604
+ cuda_q_perm,
605
+ height / 8,
606
+ width
607
+ );
608
+
609
+ // Replace qweights
610
+
611
+ cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
612
+
613
+ // Cleanup
614
+
615
+ cudaDeviceSynchronize();
616
+
617
+ cudaFree(cuda_new_qweight);
618
+ free(cpu_g_idx_map);
619
+ free(cpu_x_map);
620
+ free(cpu_x_map_inv);
621
+
622
+ return true;
623
+ }
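
The make_sequential routine above derives its row permutation from the GPTQ g_idx array in three steps: histogram the group sizes, turn the histogram into exclusive prefix sums, then scatter every row to the next free slot of its group. A minimal host-side C++ sketch of that index construction, mirroring the code above (the function and variable names here are chosen for illustration and do not appear in the kernels):

    #include <cstdint>
    #include <vector>

    // Build x_map / x_map_inv so rows sharing a group become contiguous,
    // following the histogram + prefix-sum + scatter steps of make_sequential().
    static void build_row_permutation(const std::vector<uint32_t>& g_idx, int groups,
                                      std::vector<uint32_t>& x_map,      // new_row -> old_row
                                      std::vector<uint32_t>& x_map_inv)  // old_row -> new_row
    {
        const int height = (int)g_idx.size();
        std::vector<uint32_t> base(groups, 0);

        for (int r = 0; r < height; r++) base[g_idx[r]]++;   // group histogram
        for (int g = 0, acc = 0; g < groups; g++)            // exclusive prefix sum
        {
            int tmp = (int)base[g]; base[g] = acc; acc += tmp;
        }

        x_map.resize(height);
        x_map_inv.resize(height);
        for (int r = 0; r < height; r++)                     // stable scatter
        {
            uint32_t target = base[g_idx[r]]++;
            x_map_inv[r] = target;
            x_map[target] = r;
        }
    }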
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cuh ADDED
@@ -0,0 +1,73 @@
1
+ #ifndef _q_matrix_cuh
2
+ #define _q_matrix_cuh
3
+
4
+ #include <cuda_runtime.h>
5
+ #include <cuda_fp16.h>
6
+ #include <cstdint>
7
+ #include <cstdio>
8
+
9
+ #define MAX_SUPERGROUPS 16
10
+
11
+ class QMatrix
12
+ {
13
+ public:
14
+
15
+ int device;
16
+ bool is_gptq;
17
+
18
+ int height;
19
+ int width;
20
+ int groups;
21
+ int groupsize;
22
+
23
+ int rows_8;
24
+ int rows_6;
25
+ int rows_5;
26
+ int rows_4;
27
+ int rows_3;
28
+ int rows_2;
29
+
30
+ uint32_t* cuda_q_weight = NULL;
31
+ uint16_t* cuda_q_perm = NULL;
32
+ uint16_t* cuda_q_invperm = NULL;
33
+ uint32_t* cuda_q_scale = NULL;
34
+ half* cuda_q_scale_max = NULL;
35
+ uint16_t* cuda_q_groups = NULL;
36
+ uint32_t* cuda_gptq_qzeros = NULL;
37
+ half* cuda_gptq_scales = NULL;
38
+
39
+ half* temp_dq;
40
+
41
+ bool failed;
42
+
43
+ QMatrix
44
+ (
45
+ const int _device,
46
+ const int _height,
47
+ const int _width,
48
+ const int _groups,
49
+
50
+ uint32_t* _q_weight,
51
+ uint16_t* _q_perm,
52
+ uint16_t* _q_invperm,
53
+ uint32_t* _q_scale,
54
+ half* _q_scale_max,
55
+ uint16_t* _q_groups,
56
+
57
+ uint32_t* _gptq_qzeros,
58
+ half* _gptq_scales,
59
+ uint32_t* _gptq_g_idx,
60
+
61
+ half* _temp_dq
62
+ );
63
+
64
+ ~QMatrix();
65
+
66
+ void reconstruct(half* out);
67
+ bool make_sequential(const uint32_t* cpu_g_idx);
68
+
69
+ private:
70
+
71
+ };
72
+
73
+ #endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_2.cuh ADDED
@@ -0,0 +1,103 @@
1
+ #ifndef _qdq_2_cuh
2
+ #define _qdq_2_cuh
3
+
4
+ #include "qdq_util.cuh"
5
+ #include "../../config.h"
6
+
7
+ #if QMODE_2BIT == 1
8
+
9
+ // Permutation:
10
+ //
11
+ // ffddbb99 77553311 eeccaa88 66442200
12
+
13
+ __forceinline__ __device__ void shuffle_2bit_16
14
+ (
15
+ uint32_t* q,
16
+ int stride
17
+ )
18
+ {
19
+ uint32_t qa = q[0];
20
+ uint32_t qb = 0;
21
+
22
+ #pragma unroll
23
+ for (int i = 0; i < 8; i++)
24
+ {
25
+ uint32_t qa0 = qa & 0x03;
26
+ uint32_t qa1 = (qa & 0x0c) >> 2;
27
+ qa >>= 4;
28
+ qb |= (qa1 << (i * 2 + 16));
29
+ qb |= (qa0 << (i * 2));
30
+ }
31
+ q[0] = qb;
32
+ }
33
+
34
+ __forceinline__ __device__ void dequant_2bit_16
35
+ (
36
+ const uint32_t q_0,
37
+ half2 (&dq)[8],
38
+ int stride
39
+ )
40
+ {
41
+ const uint32_t c0 = 0x64006400;
42
+ const half y4_ = __float2half_rn(1.0f / 4.0f);
43
+ const half y16_ = __float2half_rn(1.0f / 16.0f);
44
+ const half y64_ = __float2half_rn(1.0f / 64.0f);
45
+ const half2 y4 = __halves2half2(y4_, y4_);
46
+ const half2 y16 = __halves2half2(y16_, y16_);
47
+ const half2 y64 = __halves2half2(y64_, y64_);
48
+ const half z1_ = __float2half_rn(-1024.0f - 2.0f);
49
+ const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
50
+ const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
51
+ const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
52
+ const half2 z1 = __halves2half2(z1_, z1_);
53
+ const half2 z4 = __halves2half2(z4_, z4_);
54
+ const half2 z16 = __halves2half2(z16_, z16_);
55
+ const half2 z64 = __halves2half2(z64_, z64_);
56
+
57
+ uint32_t qa = q_0;
58
+ half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
59
+ half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
60
+ half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
61
+ half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
62
+ qa >>= 8;
63
+ half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9]) + 1024
64
+ half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
65
+ half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
66
+ half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
67
+
68
+ dq[0] = __hadd2(q0.as_half2, z1);
69
+ dq[1] = __hfma2(q1.as_half2, y4, z4);
70
+ dq[2] = __hfma2(q2.as_half2, y16, z16);
71
+ dq[3] = __hfma2(q3.as_half2, y64, z64);
72
+ dq[4] = __hadd2(q4.as_half2, z1);
73
+ dq[5] = __hfma2(q5.as_half2, y4, z4);
74
+ dq[6] = __hfma2(q6.as_half2, y16, z16);
75
+ dq[7] = __hfma2(q7.as_half2, y64, z64);
76
+ }
77
+
78
+ #else
79
+
80
+ __forceinline__ __device__ void shuffle_2bit_16
81
+ (
82
+ uint32_t* q,
83
+ int stride
84
+ )
85
+ {
86
+ }
87
+
88
+ __forceinline__ __device__ void dequant_2bit_16
89
+ (
90
+ const uint32_t q_0,
91
+ half2 (&dq)[8],
92
+ int stride
93
+ )
94
+ {
95
+ half dqh[16];
96
+ for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
97
+
98
+ for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
99
+ }
100
+
101
+ #endif
102
+
103
+ #endif
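
The "| c0" trick in dequant_2bit_16 above leans on the fp16 bit layout: 0x6400 encodes 1024.0, and because the spacing of fp16 values is exactly 1.0 across [1024, 2048), OR-ing any integer q in [0, 1023] into the mantissa yields exactly 1024 + q; the z constants then fold the -1024 back out together with the quantization zero. A standalone C++ check of that bit-level fact (the fp16 decoder below is written here for illustration only):

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Decode an IEEE binary16 bit pattern to float (normal numbers only).
    static float fp16_bits_to_float(uint16_t h)
    {
        int sign = (h >> 15) & 0x1;
        int exp  = (h >> 10) & 0x1f;
        int man  = h & 0x3ff;
        float v  = std::ldexp(1.0f + man / 1024.0f, exp - 15);
        return sign ? -v : v;
    }

    int main()
    {
        for (uint16_t q = 0; q < 4; q++)                  // the 2-bit codes 0..3
        {
            float v = fp16_bits_to_float(0x6400 | q);
            std::printf("0x6400 | %u -> %.1f\n", q, v);   // prints 1024.0 + q
            assert(v == 1024.0f + q);
        }
        return 0;
    }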
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_3.cuh ADDED
@@ -0,0 +1,169 @@
1
+ #ifndef _qdq_3_cuh
2
+ #define _qdq_3_cuh
3
+
4
+ #include "qdq_util.cuh"
5
+ #include "../../config.h"
6
+
7
+ #if QMODE_3BIT == 1
8
+
9
+ // Permutation:
10
+ //
11
+ // v9997775 55333111 u8886664 44222000 (u, v lsb)
12
+ // vjjjhhhf ffdddbbb uiiiggge eecccaaa
13
+ // vtttrrrp ppnnnlll usssqqqo oommmkkk
14
+
15
+ __forceinline__ __device__ void shuffle_3bit_32
16
+ (
17
+ uint32_t* q,
18
+ int stride
19
+ )
20
+ {
21
+ uint32_t qa = q[0 * stride];
22
+ uint32_t qb = q[1 * stride];
23
+ uint32_t qc = q[2 * stride];
24
+
25
+ // qa: aa999888 77766655 54443332 22111000
26
+ // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
27
+ // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
28
+
29
+ uint32_t qd = qc >> 26;
30
+ qc <<= 4;
31
+ qc |= qb >> 28;
32
+ qb <<= 2;
33
+ qb |= qa >> 30;
34
+
35
+ // qa: ..999888 77766655 54443332 22111000
36
+ // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
37
+ // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
38
+ // qd: vvvuuu
39
+
40
+ uint32_t za = 0;
41
+ uint32_t zb = 0;
42
+ uint32_t zc = 0;
43
+
44
+ for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
45
+ for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
46
+ for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
47
+
48
+ // za: 9997775 55333111 8886664 44222000
49
+ // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
50
+ // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
51
+ // qd: vvvuuu
52
+
53
+ za |= ((qd & 0x01) >> 0) << 15;
54
+ zb |= ((qd & 0x02) >> 1) << 15;
55
+ zc |= ((qd & 0x04) >> 2) << 15;
56
+ za |= ((qd & 0x08) >> 3) << 31;
57
+ zb |= ((qd & 0x10) >> 4) << 31;
58
+ zc |= ((qd & 0x20) >> 5) << 31;
59
+
60
+ // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
61
+ // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
62
+ // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
63
+
64
+ q[0 * stride] = za;
65
+ q[1 * stride] = zb;
66
+ q[2 * stride] = zc;
67
+ }
68
+
69
+ __forceinline__ __device__ void dequant_3bit_32
70
+ (
71
+ const uint32_t q_0,
72
+ const uint32_t q_1,
73
+ const uint32_t q_2,
74
+ half2 (&dq)[16],
75
+ int stride
76
+ )
77
+ {
78
+ const uint32_t c0 = 0x64006400;
79
+ const half y8_ = __float2half_rn(1.0f / 8.0f);
80
+ const half y64_ = __float2half_rn(1.0f / 64.0f);
81
+ const half2 y8 = __halves2half2(y8_, y8_);
82
+ const half2 y64 = __halves2half2(y64_, y64_);
83
+ const half z1_ = __float2half_rn(-1024.0f - 4.0f);
84
+ const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
85
+ const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
86
+ const half2 z1 = __halves2half2(z1_, z1_);
87
+ const half2 z8 = __halves2half2(z8_, z8_);
88
+ const half2 z64 = __halves2half2(z64_, z64_);
89
+
90
+ uint32_t qa = q_0;
91
+ uint32_t qb = q_1;
92
+ uint32_t qc = q_2;
93
+
94
+ half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
95
+ half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
96
+ qa >>= 6;
97
+ half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
98
+ half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
99
+ half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
100
+ qa >>= 9;
101
+ qa &= 0x00010001;
102
+ half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
103
+ half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
104
+ qb >>= 6;
105
+ half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
106
+ half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
107
+ half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
108
+ qb >>= 8;
109
+ qb &= 0x00020002;
110
+ half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
111
+ half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
112
+ qc >>= 6;
113
+ half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
114
+ half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
115
+ half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
116
+ qc >>= 7;
117
+ qc &= 0x00040004;
118
+ half2_uint32 q15((qa | qb | qc) | c0);
119
+
120
+ dq[ 0] = __hadd2( q0.as_half2, z1);
121
+ dq[ 1] = __hfma2( q1.as_half2, y8, z8);
122
+ dq[ 2] = __hadd2( q2.as_half2, z1);
123
+ dq[ 3] = __hfma2( q3.as_half2, y8, z8);
124
+ dq[ 4] = __hfma2( q4.as_half2, y64, z64);
125
+ dq[ 5] = __hadd2( q5.as_half2, z1);
126
+ dq[ 6] = __hfma2( q6.as_half2, y8, z8);
127
+ dq[ 7] = __hadd2( q7.as_half2, z1);
128
+ dq[ 8] = __hfma2( q8.as_half2, y8, z8);
129
+ dq[ 9] = __hfma2( q9.as_half2, y64, z64);
130
+ dq[10] = __hadd2(q10.as_half2, z1);
131
+ dq[11] = __hfma2(q11.as_half2, y8, z8);
132
+ dq[12] = __hadd2(q12.as_half2, z1);
133
+ dq[13] = __hfma2(q13.as_half2, y8, z8);
134
+ dq[14] = __hfma2(q14.as_half2, y64, z64);
135
+ dq[15] = __hadd2(q15.as_half2, z1);
136
+ }
137
+
138
+ #else
139
+
140
+ __forceinline__ __device__ void shuffle_3bit_32
141
+ (
142
+ uint32_t* q,
143
+ int stride
144
+ )
145
+ {
146
+ }
147
+
148
+ __forceinline__ __device__ void dequant_3bit_32
149
+ (
150
+ const uint32_t q_0,
151
+ const uint32_t q_1,
152
+ const uint32_t q_2,
153
+ half2 (&dq)[16],
154
+ int stride
155
+ )
156
+ {
157
+ half dqh[32];
158
+ for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
159
+ dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
160
+ for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
161
+ dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
162
+ for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
163
+
164
+ for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
165
+ }
166
+
167
+ #endif
168
+
169
+ #endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_4.cuh ADDED
@@ -0,0 +1,227 @@
1
+ #ifndef _qdq_4_cuh
2
+ #define _qdq_4_cuh
3
+
4
+ #include "qdq_util.cuh"
5
+ #include "../../config.h"
6
+
7
+ #if QMODE_4BIT == 1
8
+
9
+ // Permutation:
10
+ //
11
+ // 77775555 33331111 66664444 22220000
12
+
13
+ __forceinline__ __device__ void shuffle_4bit_8
14
+ (
15
+ uint32_t* q,
16
+ int stride
17
+ )
18
+ {
19
+ uint32_t qa = q[0];
20
+ uint32_t qb = 0;
21
+
22
+ #pragma unroll
23
+ for (int i = 0; i < 4; i++)
24
+ {
25
+ uint32_t qa0 = qa & 0x0f;
26
+ uint32_t qa1 = (qa & 0xf0) >> 4;
27
+ qa >>= 8;
28
+ qb |= (qa1 << (i * 4 + 16));
29
+ qb |= (qa0 << (i * 4));
30
+ }
31
+ q[0] = qb;
32
+ }
33
+
34
+ __forceinline__ __device__ void dequant_4bit_8
35
+ (
36
+ const uint32_t q_0,
37
+ half2 (&dq)[4],
38
+ int stride
39
+ )
40
+ {
41
+ const uint32_t c0 = 0x64006400;
42
+ const half y16_ = __float2half_rn(1.0f / 16.0f);
43
+ const half2 y16 = __halves2half2(y16_, y16_);
44
+ const half z1_ = __float2half_rn(-1024.0f - 8.0f);
45
+ const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
46
+ const half2 z1 = __halves2half2(z1_, z1_);
47
+ const half2 z16 = __halves2half2(z16_, z16_);
48
+
49
+ uint32_t qa = q_0;
50
+ half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
51
+ half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
52
+ qa >>= 8;
53
+ half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
54
+ half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
55
+
56
+ dq[0] = __hadd2(q0.as_half2, z1);
57
+ dq[1] = __hfma2(q1.as_half2, y16, z16);
58
+ dq[2] = __hadd2(q2.as_half2, z1);
59
+ dq[3] = __hfma2(q3.as_half2, y16, z16);
60
+ }
61
+
62
+ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
63
+ (
64
+ const uint32_t zero,
65
+ const half scale,
66
+ half2 (&z1z16)[2],
67
+ half2 (&y1y16)[2]
68
+ )
69
+ {
70
+ half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
71
+ half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
72
+
73
+ half2 scale2 = __half2half2(scale);
74
+
75
+ z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
76
+ z1z16[1] = __hmul2(scale2, __half2half2(z16));
77
+
78
+ const half y1 = __float2half_rn(1.0f);
79
+ const half y16 = __float2half_rn(1.0f / 16.0f);
80
+
81
+ y1y16[0] = __hmul2(scale2, __half2half2(y1));
82
+ y1y16[1] = __hmul2(scale2, __half2half2(y16));
83
+ }
84
+
85
+ __forceinline__ __device__ void dequant_4bit_8_prep_zero
86
+ (
87
+ const uint32_t zero,
88
+ half2(&z1z16)[2],
89
+ half2(&y1y16)[2]
90
+ )
91
+ {
92
+ half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
93
+ half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
94
+
95
+ z1z16[0] = __half2half2(z1.as_half);
96
+ z1z16[1] = __half2half2(z16);
97
+
98
+ const half y1 = __float2half_rn(1.0f);
99
+ const half y16 = __float2half_rn(1.0f / 16.0f);
100
+
101
+ y1y16[0] = __half2half2(y1);
102
+ y1y16[1] = __half2half2(y16);
103
+ }
104
+
105
+
106
+ __forceinline__ __device__ void dequant_4bit_8_gptq
107
+ (
108
+ const uint32_t q_0,
109
+ half2 (&dq)[4],
110
+ half2 (&z1z16)[2],
111
+ half2 (&y1y16)[2],
112
+ int stride,
113
+ bool scaled
114
+ )
115
+ {
116
+ const uint32_t c0 = 0x64006400;
117
+
118
+ uint32_t qa = q_0;
119
+ half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
120
+ half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
121
+ qa >>= 8;
122
+ half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
123
+ half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
124
+
125
+ if (scaled)
126
+ {
127
+ dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
128
+ dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
129
+ dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
130
+ dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
131
+ }
132
+ else
133
+ {
134
+ dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
135
+ dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
136
+ dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
137
+ dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
138
+ }
139
+ }
140
+
141
+ #else
142
+
143
+ __forceinline__ __device__ void shuffle_4bit_8
144
+ (
145
+ uint32_t* q,
146
+ int stride
147
+ )
148
+ {
149
+ }
150
+
151
+ __forceinline__ __device__ void dequant_4bit_8
152
+ (
153
+ const uint32_t q_0,
154
+ half2 (&dq)[4],
155
+ int stride
156
+ )
157
+ {
158
+ half dqh[8];
159
+ for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
160
+
161
+ for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
162
+ }
163
+
164
+ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
165
+ (
166
+ const uint32_t zero,
167
+ const half scale,
168
+ half2 (&z1)[2],
169
+ half2 (&y1)[2]
170
+ )
171
+ {
172
+ half z = __int2half_rn(-((int)zero));
173
+ z = __hmul(z, scale);
174
+ z1[0] = __half2half2(z);
175
+ y1[0] = __half2half2(scale);
176
+ }
177
+
178
+ __forceinline__ __device__ void dequant_4bit_8_prep_zero
179
+ (
180
+ const uint32_t zero,
181
+ half2(&z1)[2],
182
+ half2(&y1)[2]
183
+ )
184
+ {
185
+ half z = __int2half_rn(-((int)zero));
186
+ z1[0] = __half2half2(z);
187
+ }
188
+
189
+ __forceinline__ __device__ void dequant_4bit_8_gptq
190
+ (
191
+ const uint32_t q_0,
192
+ half2 (&dq)[4],
193
+ half2 (&z1)[2],
194
+ half2 (&y1)[2],
195
+ int stride,
196
+ bool scaled
197
+ )
198
+ {
199
+ half2 dqh2[8];
200
+
201
+ uint32_t qa = q_0;
202
+ for (int i = 0; i < 4; i++)
203
+ {
204
+ half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
205
+ half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
206
+ dqh2[i] = __halves2half2(d0, d1);
207
+ }
208
+
209
+ if (scaled)
210
+ {
211
+ dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
212
+ dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
213
+ dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
214
+ dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
215
+ }
216
+ else
217
+ {
218
+ dq[0] = __hadd2(dqh2[0], z1[0]);
219
+ dq[1] = __hadd2(dqh2[1], z1[0]);
220
+ dq[2] = __hadd2(dqh2[2], z1[0]);
221
+ dq[3] = __hadd2(dqh2[3], z1[0]);
222
+ }
223
+ }
224
+
225
+ #endif
226
+
227
+ #endif
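
The GPTQ zero-point preparation above uses the same representation: 0xe400 is the fp16 encoding of -1024.0, so 0xe400 | zero encodes -(1024 + zero) for a 4-bit zero point, and adding it to the biased lane value q + 1024 from dequant_4bit_8_gptq leaves q - zero; the scaled variant simply premultiplies both terms by the group scale so a single fma finishes the job. A worked numeric check in plain C++ (the values are made up for illustration; the scale is a power of two so the float comparisons are exact):

    #include <cassert>

    int main()
    {
        const float q = 11.0f, zero = 5.0f, scale = 0.25f;

        float biased = q + 1024.0f;         // lane built by (qa & 0x000f000f) | 0x6400
        float z1     = -(1024.0f + zero);   // constant encoded by 0xe400 | zero
        float plain  = (biased + z1) * scale;        // unscaled prep: (q - zero) * s
        float fused  = biased * scale + z1 * scale;  // scaled prep: fma with premultiplied z1

        assert(plain == (q - zero) * scale);
        assert(fused == plain);
        return 0;
    }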
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_5.cuh ADDED
@@ -0,0 +1,207 @@
1
+ #ifndef _qdq_5_cuh
2
+ #define _qdq_5_cuh
3
+
4
+ #include "qdq_util.cuh"
5
+ #include "../../config.h"
6
+
7
+ #if QMODE_5BIT == 1
8
+
9
+ // Permutation:
10
+ //
11
+ // v5555533 33311111 u4444422 22200000 (u, v lsb)
12
+ // vbbbbb99 99977777 uaaaaa88 88866666
13
+ // vhhhhhff fffddddd ugggggee eeeccccc
14
+ // vnnnnnll llljjjjj ummmmmkk kkkiiiii
15
+ // vtttttrr rrrppppp usssssqq qqqooooo
16
+
17
+ __forceinline__ __device__ void shuffle_5bit_32
18
+ (
19
+ uint32_t* q,
20
+ int stride
21
+ )
22
+ {
23
+ uint32_t qa = q[0 * stride];
24
+ uint32_t qb = q[1 * stride];
25
+ uint32_t qc = q[2 * stride];
26
+ uint32_t qd = q[3 * stride];
27
+ uint32_t qe = q[4 * stride];
28
+
29
+ // qa: 66555554 44443333 32222211 11100000
30
+ // qb: ccccbbbb baaaaa99 99988888 77777666
31
+ // qc: jiiiiihh hhhggggg fffffeee eedddddc
32
+ // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
33
+ // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
34
+
35
+ uint32_t qf = qe >> 22;
36
+ qe <<= 8;
37
+ qe |= qd >> 24;
38
+ qd <<= 6;
39
+ qd |= qc >> 26;
40
+ qc <<= 4;
41
+ qc |= qb >> 28;
42
+ qb <<= 2;
43
+ qb |= qa >> 30;
44
+
45
+ // qa: 555554 44443333 32222211 11100000
46
+ // qb: bbbbba aaaa9999 98888877 77766666
47
+ // qc: hhhhhg ggggffff feeeeedd dddccccc
48
+ // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
49
+ // qe: ttttts ssssrrrr rqqqqqpp pppooooo
50
+ // qf: vv vvvuuuuu
51
+
52
+ uint32_t za = 0;
53
+ uint32_t zb = 0;
54
+ uint32_t zc = 0;
55
+ uint32_t zd = 0;
56
+ uint32_t ze = 0;
57
+
58
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
59
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
60
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
61
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
62
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
63
+
64
+ // za: 5555533 33311111 4444422 22200000
65
+ // zb: bbbbb99 99977777 aaaaa88 88866666
66
+ // zc: hhhhhff fffddddd gggggee eeeccccc
67
+ // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
68
+ // ze: tttttrr rrrppppp sssssqq qqqooooo
69
+ // qf: vv vvvuuuuu
70
+
71
+ za |= ((qf & 0x001) >> 0) << 15;
72
+ zb |= ((qf & 0x002) >> 1) << 15;
73
+ zc |= ((qf & 0x004) >> 2) << 15;
74
+ zd |= ((qf & 0x008) >> 3) << 15;
75
+ ze |= ((qf & 0x010) >> 4) << 15;
76
+ za |= ((qf & 0x020) >> 5) << 31;
77
+ zb |= ((qf & 0x040) >> 6) << 31;
78
+ zc |= ((qf & 0x080) >> 7) << 31;
79
+ zd |= ((qf & 0x100) >> 8) << 31;
80
+ ze |= ((qf & 0x200) >> 9) << 31;
81
+
82
+ // za: v5555533 33311111 u4444422 22200000 (u, v lsb)
83
+ // zb: vbbbbb99 99977777 uaaaaa88 88866666
84
+ // zc: vhhhhhff fffddddd ugggggee eeeccccc
85
+ // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
86
+ // ze: vtttttrr rrrppppp usssssqq qqqooooo
87
+
88
+ q[0 * stride] = za;
89
+ q[1 * stride] = zb;
90
+ q[2 * stride] = zc;
91
+ q[3 * stride] = zd;
92
+ q[4 * stride] = ze;
93
+ }
94
+
95
+ __forceinline__ __device__ void dequant_5bit_32
96
+ (
97
+ const uint32_t q_0,
98
+ const uint32_t q_1,
99
+ const uint32_t q_2,
100
+ const uint32_t q_3,
101
+ const uint32_t q_4,
102
+ half2 (&dq)[16],
103
+ int stride
104
+ )
105
+ {
106
+ const uint32_t c0 = 0x64006400;
107
+ const half y32_ = __float2half_rn(1.0f / 32.0f);
108
+ const half2 y32 = __halves2half2(y32_, y32_);
109
+ const half z1_ = __float2half_rn(-1024.0f - 16.0f);
110
+ const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
111
+ const half2 z1 = __halves2half2(z1_, z1_);
112
+ const half2 z32 = __halves2half2(z32_, z32_);
113
+
114
+ uint32_t qa = q_0;
115
+ uint32_t qb = q_1;
116
+ uint32_t qc = q_2;
117
+ uint32_t qd = q_3;
118
+ uint32_t qe = q_4;
119
+
120
+ half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
121
+ half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
122
+ qa >>= 10;
123
+ half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
124
+ qa >>= 5;
125
+ qa &= 0x00010001;
126
+ half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
127
+ half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
128
+ qb >>= 10;
129
+ half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
130
+ qb >>= 4;
131
+ qb &= 0x00020002;
132
+ half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
133
+ half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
134
+ qc >>= 10;
135
+ half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
136
+ qc >>= 3;
137
+ qc &= 0x00040004;
138
+ half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
139
+ half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
140
+ qd >>= 10;
141
+ half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
142
+ qd >>= 2;
143
+ qd &= 0x00080008;
144
+ half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
145
+ half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
146
+ qe >>= 10;
147
+ half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
148
+ qe >>= 1;
149
+ qe &= 0x00100010;
150
+ half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
151
+
152
+ dq[ 0] = __hadd2( q0.as_half2, z1);
153
+ dq[ 1] = __hfma2( q1.as_half2, y32, z32);
154
+ dq[ 2] = __hadd2( q2.as_half2, z1);
155
+ dq[ 3] = __hadd2( q3.as_half2, z1);
156
+ dq[ 4] = __hfma2( q4.as_half2, y32, z32);
157
+ dq[ 5] = __hadd2( q5.as_half2, z1);
158
+ dq[ 6] = __hadd2( q6.as_half2, z1);
159
+ dq[ 7] = __hfma2( q7.as_half2, y32, z32);
160
+ dq[ 8] = __hadd2( q8.as_half2, z1);
161
+ dq[ 9] = __hadd2( q9.as_half2, z1);
162
+ dq[10] = __hfma2(q10.as_half2, y32, z32);
163
+ dq[11] = __hadd2(q11.as_half2, z1);
164
+ dq[12] = __hadd2(q12.as_half2, z1);
165
+ dq[13] = __hfma2(q13.as_half2, y32, z32);
166
+ dq[14] = __hadd2(q14.as_half2, z1);
167
+ dq[15] = __hadd2(q15.as_half2, z1);
168
+ }
169
+
170
+ #else
171
+
172
+ __forceinline__ __device__ void shuffle_5bit_32
173
+ (
174
+ uint32_t* q,
175
+ int stride
176
+ )
177
+ {
178
+ }
179
+
180
+ __forceinline__ __device__ void dequant_5bit_32
181
+ (
182
+ const uint32_t q_0,
183
+ const uint32_t q_1,
184
+ const uint32_t q_2,
185
+ const uint32_t q_3,
186
+ const uint32_t q_4,
187
+ half2 (&dq)[16],
188
+ int stride
189
+ )
190
+ {
191
+ half dqh[32];
192
+ for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
193
+ dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
194
+ for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
195
+ dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
196
+ for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
197
+ dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
198
+ for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
199
+ dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
200
+ for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
201
+
202
+ for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
203
+ }
204
+
205
+ #endif
206
+
207
+ #endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_6.cuh ADDED
@@ -0,0 +1,44 @@
1
+ #ifndef _qdq_6_cuh
2
+ #define _qdq_6_cuh
3
+
4
+ #include "qdq_util.cuh"
5
+ #include "../../config.h"
6
+
7
+ #if QMODE_6BIT == 1
8
+
9
+ // Not implemented
10
+
11
+ #else
12
+
13
+ __forceinline__ __device__ void shuffle_6bit_16
14
+ (
15
+ uint32_t* q,
16
+ int stride
17
+ )
18
+ {
19
+ }
20
+
21
+ __forceinline__ __device__ void dequant_6bit_16
22
+ (
23
+ const uint32_t q_0,
24
+ const uint32_t q_1,
25
+ const uint32_t q_2,
26
+ half2 (&dq)[8],
27
+ int stride
28
+ )
29
+ {
30
+ half dqh[16];
31
+ for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
32
+ dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
33
+ for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
34
+ dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
35
+ for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
36
+
37
+ for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
38
+ }
39
+
40
+ #endif
41
+
42
+ #endif
43
+
44
+
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_8.cuh ADDED
@@ -0,0 +1,38 @@
1
+ #ifndef _qdq_8_cuh
2
+ #define _qdq_8_cuh
3
+
4
+ #include "qdq_util.cuh"
5
+ #include "../../config.h"
6
+
7
+ #if QMODE_8BIT == 1
8
+
9
+ // Not implemented
10
+
11
+ #else
12
+
13
+ __forceinline__ __device__ void shuffle_8bit_4
14
+ (
15
+ uint32_t* q,
16
+ int stride
17
+ )
18
+ {
19
+ }
20
+
21
+ __forceinline__ __device__ void dequant_8bit_8
22
+ (
23
+ const uint32_t q_0,
24
+ const uint32_t q_1,
25
+ half2 (&dq)[4],
26
+ int stride
27
+ )
28
+ {
29
+ half dqh[8];
30
+ for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
31
+ for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
32
+
33
+ for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
34
+ }
35
+
36
+ #endif
37
+
38
+ #endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_util.cuh ADDED
@@ -0,0 +1,51 @@
1
+ #ifndef _qdq_util_cuh
2
+ #define _qdq_util_cuh
3
+
4
+ union half2_uint32
5
+ {
6
+ uint32_t as_uint32;
7
+ half2 as_half2;
8
+ __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
9
+ __device__ half2_uint32(half2 val) : as_half2(val) {}
10
+ };
11
+
12
+ union half_uint16
13
+ {
14
+ uint16_t as_uint16;
15
+ half as_half;
16
+ __device__ half_uint16(uint16_t val) : as_uint16(val) {}
17
+ __device__ half_uint16(half val) : as_half(val) {}
18
+ };
19
+
20
+ // Max_scale premultiplied by 1/256
21
+
22
+ __forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
23
+ {
24
+ int qs_i = qs + 1;
25
+ half qs_h = __int2half_rn(qs_i * qs_i);
26
+ qs_h = __hmul(qs_h, max_scale);
27
+ return qs_h;
28
+ }
29
+
30
+ __forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
31
+ {
32
+ return __hmul(__int2half_rn(q - qzero), scale);
33
+ }
34
+
35
+ __forceinline__ __device__ half dq_ns(const int q, const int qzero)
36
+ {
37
+ //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
38
+ return __int2half_rn(q - qzero);
39
+ }
40
+
41
+ __forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
42
+ {
43
+ return (int)((q >> shift) & mask);
44
+ }
45
+
46
+ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
47
+ {
48
+ return (int)(__funnelshift_rc(q0, q1, shift) & mask);
49
+ }
50
+
51
+ #endif
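
dq_scale above reconstructs a per-group scale from a 4-bit code: scale = (qs + 1)^2 * q_scale_max[group], where, per the comment, the stored maximum has already been premultiplied by 1/256. That is the same value the dq_scale_ helper in util.cuh produces by dividing (qs + 1) by 16 before squaring. A quick host-side check of the equivalence, using plain floats as a stand-in for the half arithmetic (the example maximum is arbitrary):

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float true_max      = 0.125f;            // example group maximum
        const float premultiplied = true_max / 256.0f; // what the premultiplied max stores

        for (int qs = 0; qs < 16; qs++)                // 4-bit scale codes
        {
            float a = (qs + 1) * (qs + 1) * premultiplied;   // qdq_util.cuh: dq_scale
            float b = (qs + 1) / 16.0f;                      // util.cuh: dq_scale_
            b = b * b * true_max;
            assert(std::fabs(a - b) < 1e-9f);
        }
        return 0;
    }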
AutoAWQ_kernels/awq_ext/exllamav2/cuda/util.cuh ADDED
@@ -0,0 +1,42 @@
1
+
2
+ #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
3
+
4
+ #define DBGS(__x) printf("%s\n", __x)
5
+ #define DBGI(__x) printf("%s: %i\n", #__x, __x)
6
+ #define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
7
+ #define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
8
+ #define DBGX(__x) printf("%s: %x\n", #__x, __x)
9
+ #define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
10
+ #define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
11
+ #define DBGF(__x) printf("%s: %f\n", #__x, __x)
12
+ #define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
13
+ #define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
14
+ #define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
15
+ #define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
16
+ #define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
17
+
18
+ #define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
19
+ #define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
20
+
21
+ __forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
22
+ {
23
+ half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
24
+ qs_h = __hmul(qs_h, qs_h);
25
+ qs_h = __hmul(qs_h, max_scale);
26
+ return qs_h;
27
+ }
28
+
29
+ __forceinline__ __device__ float clamp(float x, float a, float b)
30
+ {
31
+ return fmaxf(a, fminf(b, x));
32
+ }
33
+
34
+ #define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
35
+ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
36
+ {
37
+ if (code != cudaSuccess)
38
+ {
39
+ fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
40
+ if (abort) exit(code);
41
+ }
42
+ }
AutoAWQ_kernels/awq_ext/exllamav2/ext.cpp ADDED
@@ -0,0 +1,134 @@
1
+ #include <torch/extension.h>
2
+ #include <c10/cuda/CUDAGuard.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <cuda_runtime.h>
5
+ #include <cuda_fp16.h>
6
+ #include <cstdint>
7
+ #include <cstdio>
8
+
9
+ #include "config.h"
10
+
11
+ #include "cuda/q_matrix.cuh"
12
+ #include "cuda/q_gemm.cuh"
13
+
14
+ #include "cpp/util.h"
15
+
16
+ // Some decluttering macros
17
+
18
+ #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
19
+ #define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
20
+ #define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
21
+ #define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
22
+
23
+
24
+ // Quant matrix
25
+
26
+ uintptr_t make_q_matrix
27
+ (
28
+ torch::Tensor q_weight,
29
+ torch::Tensor q_perm,
30
+ torch::Tensor q_invperm,
31
+ torch::Tensor q_scale,
32
+ torch::Tensor q_scale_max,
33
+ torch::Tensor q_groups,
34
+ torch::Tensor gptq_qzeros,
35
+ torch::Tensor gptq_scales,
36
+ torch::Tensor gptq_g_idx,
37
+ torch::Tensor temp_dq
38
+ )
39
+ {
40
+ TORCH_CHECK_DTYPE(q_weight, kInt);
41
+ TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
42
+ TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
43
+ TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
44
+ TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
45
+ TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
46
+ TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
47
+ TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
48
+ TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
49
+
50
+ TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
51
+
52
+ int device = q_weight.device().index();
53
+ int width = q_weight.size(1);
54
+ int groups;
55
+ int height;
56
+
57
+ if (!q_scale.device().is_meta())
58
+ {
59
+ TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
60
+ TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
61
+ groups = q_scale.size(0);
62
+ height = q_invperm.size(0);
63
+ }
64
+ else
65
+ {
66
+ TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
67
+ TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
68
+ groups = gptq_qzeros.size(0);
69
+ height = q_weight.size(0) * 8;
70
+ }
71
+
72
+ TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
73
+
74
+ QMatrix* m = new QMatrix
75
+ (
76
+ device,
77
+ height,
78
+ width,
79
+ groups,
80
+ (uint32_t*) q_weight.data_ptr(),
81
+ q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
82
+ q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
83
+ q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
84
+ q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
85
+ q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
86
+ gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
87
+ gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
88
+ gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
89
+ (half*) temp_dq.data_ptr()
90
+ );
91
+
92
+ return reinterpret_cast<uintptr_t> (m);
93
+ }
94
+
95
+ void gemm_half_q_half
96
+ (
97
+ torch::Tensor a,
98
+ uintptr_t b,
99
+ torch::Tensor c,
100
+ bool force_cuda
101
+ )
102
+ {
103
+ QMatrix* qm = reinterpret_cast<QMatrix*> (b);
104
+
105
+ TORCH_CHECK_DTYPE(a, kHalf);
106
+ TORCH_CHECK_DTYPE(c, kHalf);
107
+ TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
108
+ TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
109
+ TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
110
+
111
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
112
+
113
+ gemm_half_q_half_cuda
114
+ (
115
+ at::cuda::getCurrentCUDABlasHandle(),
116
+ (const half*) a.data_ptr(),
117
+ qm,
118
+ (half*) c.data_ptr(),
119
+ c.size(0), // m
120
+ c.size(1), // n
121
+ a.size(1), // k
122
+ true,
123
+ NULL,
124
+ force_cuda
125
+ );
126
+ }
127
+
128
+ // Bindings
129
+
130
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
131
+ {
132
+ m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
133
+ m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
134
+ }
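
The checks in make_q_matrix and gemm_half_q_half above fix the shape contract of the binding: for EXL2 weights the K dimension is taken from q_invperm, for GPTQ it is q_weight.size(0) * 8 (eight 4-bit values per packed int32 row), and temp_dq must provide at least height * width half elements for reconstruction. A small C++ restatement of that contract (the struct and function names are invented here and are not part of the extension):

    #include <cstddef>
    #include <stdexcept>

    struct QMatrixShape { int height; int width; };  // K and N of the dequantized weight

    // temp_dq.size(0) >= width * height, per the TORCH_CHECK in make_q_matrix().
    static size_t min_temp_dq_elems(const QMatrixShape& s)
    {
        return (size_t)s.height * (size_t)s.width;
    }

    // Mirrors the shape checks in gemm_half_q_half(): a is [m, k], c is [m, n].
    static void check_gemm_shapes(int a_rows, int a_cols, const QMatrixShape& s,
                                  int c_rows, int c_cols)
    {
        if (a_cols != s.height) throw std::runtime_error("a and b have incompatible shapes");
        if (c_cols != s.width)  throw std::runtime_error("b and c have incompatible shapes");
        if (a_rows != c_rows)   throw std::runtime_error("a and c have incompatible shapes");
    }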
AutoAWQ_kernels/awq_ext/layernorm/layernorm.cu ADDED
@@ -0,0 +1,113 @@
1
+ /*
2
+
3
+ Adapted from NVIDIA FasterTransformer:
4
+ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
5
+
6
+ */
7
+
8
+ #include <torch/extension.h>
9
+ #include <cuda_fp16.h>
10
+ #include "reduction.cuh"
11
+ #include "layernorm.h"
12
+ #include <cuda_runtime.h>
13
+ #include <c10/cuda/CUDAGuard.h>
14
+
15
+ static inline __device__ float to_float(half src)
16
+ {
17
+ return __half2float(src);
18
+ }
19
+
20
+ static inline __device__ float to_float(float src)
21
+ {
22
+ return src;
23
+ }
24
+
25
+ template<typename T>
26
+ __global__ void generalT5LayerNorm(
27
+ const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
28
+ {
29
+ // LayerNorm in the T5 style: no bias and no subtraction of the mean.
30
+ const int tid = threadIdx.x;
31
+
32
+ __shared__ float s_variance;
33
+ float variance = 0.0f;
34
+
35
+ float local_var_sum = 0.0f;
36
+ for (int i = tid; i < n; i += blockDim.x) {
37
+ float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
38
+ local_var_sum += diff * diff;
39
+ }
40
+ variance = blockReduceSum(local_var_sum);
41
+
42
+ if (threadIdx.x == 0) {
43
+ s_variance = rsqrtf(variance / (float)n + layernorm_eps);
44
+ }
45
+ __syncthreads();
46
+
47
+ for (int i = tid; i < n; i += blockDim.x) {
48
+ output[blockIdx.x * n + i] =
49
+ clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
50
+ }
51
+ }
52
+
53
+
54
+ template<typename T>
55
+ void invokeGeneralT5LayerNorm(T* out,
56
+ const T* input,
57
+ const T* gamma,
58
+ // const T* beta,
59
+ const float layernorm_eps,
60
+ const int m,
61
+ const int n)
62
+ {
63
+ dim3 grid(m);
64
+ dim3 block(min(n, 1024));
65
+
66
+ /* For general cases, n is equal to hidden_units, e.g., 512/1024.
67
+ Since we have warp shuffle inside the code, block.x % 32 should be 0.
68
+ */
69
+ if (n % 32 != 0) {
70
+ block.x = 1024;
71
+ }
72
+
73
+ block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x
74
+
75
+ /* should pay attention to the rsqrt precision*/
76
+ generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3
77
+ }
78
+
79
+ template void invokeGeneralT5LayerNorm(half* out,
80
+ const half* input,
81
+ const half* gamma,
82
+ // const half* beta,
83
+ const float layernorm_eps,
84
+ const int m,
85
+ const int n);
86
+
87
+ template void invokeGeneralT5LayerNorm(float* out,
88
+ const float* input,
89
+ const float* gamma,
90
+ // const half* beta,
91
+ const float layernorm_eps,
92
+ const int m,
93
+ const int n);
94
+
95
+
96
+
97
+ // input b, n, c
98
+ void layernorm_forward_cuda(
99
+ torch::Tensor _input,
100
+ torch::Tensor _gamma,
101
+ torch::Tensor _out,
102
+ float eps)
103
+ {
104
+ int m = _input.size(0) * _input.size(1);
105
+ int n = _input.size(2);
106
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));
107
+
108
+ auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
109
+ auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
110
+ auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());
111
+
112
+ invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
113
+ }
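
generalT5LayerNorm above is an RMS-style normalization: each row is scaled by rsqrt(mean(x^2) + eps) and then by gamma, with no mean subtraction and no bias term. A host-side reference of the same per-row computation in plain C++, useful as a mental model for the kernel rather than a drop-in test harness:

    #include <cmath>
    #include <vector>

    // out[i] = x[i] * gamma[i] / sqrt(mean(x^2) + eps), one row at a time.
    static void rms_norm_row(const std::vector<float>& x, const std::vector<float>& gamma,
                             float eps, std::vector<float>& out)
    {
        const size_t n = x.size();
        double sum_sq = 0.0;
        for (size_t i = 0; i < n; i++) sum_sq += (double)x[i] * x[i];
        const float inv_rms = 1.0f / std::sqrt((float)(sum_sq / (double)n) + eps);
        out.resize(n);
        for (size_t i = 0; i < n; i++) out[i] = x[i] * inv_rms * gamma[i];
    }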
AutoAWQ_kernels/awq_ext/layernorm/layernorm.h ADDED
@@ -0,0 +1,3 @@
1
+ #include <torch/extension.h>
2
+
3
+ void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
AutoAWQ_kernels/awq_ext/layernorm/reduction.cuh ADDED
@@ -0,0 +1,82 @@
1
+ /*
2
+
3
+ Adapted from NVIDIA FasterTransformer:
4
+ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
5
+ */
6
+
7
+ #pragma once
8
+ #include <assert.h>
9
+ #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
10
+ #include <cooperative_groups/reduce.h>
11
+ #else
12
+ #include <cooperative_groups.h>
13
+ #endif
14
+ #include <cuda_fp16.h>
15
+ #include <cuda_runtime.h>
16
+ #include <float.h>
17
+ #include <type_traits>
18
+
19
+ #define HALF_FLT_MAX 65504.F
20
+ #define FINAL_MASK 0xffffffff
21
+
22
+
23
+ template<typename T>
24
+ inline __device__ T add(T a, T b) {
25
+ return a + b;
26
+ }
27
+
28
+ template<>
29
+ inline __device__ half2 add(half2 a, half2 b) {
30
+ return __hadd2(a, b);
31
+ }
32
+
33
+ template<>
34
+ inline __device__ half add(half a, half b) {
35
+ return __hadd(a, b);
36
+ }
37
+
38
+ template<typename T>
39
+ __inline__ __device__ T warpReduceSum(T val)
40
+ {
41
+ #pragma unroll
42
+ for (int mask = 16; mask > 0; mask >>= 1)
43
+ val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
44
+ return val;
45
+ }
46
+
47
+ /* Calculate the sum of all elements in a block */
48
+ template<typename T>
49
+ __inline__ __device__ T blockReduceSum(T val)
50
+ {
51
+ static __shared__ T shared[32];
52
+ int lane = threadIdx.x & 0x1f;
53
+ int wid = threadIdx.x >> 5;
54
+
55
+ val = warpReduceSum<T>(val);
56
+
57
+ if (lane == 0)
58
+ shared[wid] = val;
59
+
60
+ __syncthreads();
61
+
62
+ // Use blockDim.x / 32.f rather than blockDim.x >> 5 so the partial warp's
63
+ // sum is not dropped when blockDim.x is not a multiple of 32.
64
+ val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
65
+ val = warpReduceSum<T>(val);
66
+
67
+ return val;
68
+ }
69
+
70
+
71
+ template<typename T>
72
+ __device__ __forceinline__ T clamp_inf_for_half(const float input)
73
+ {
74
+ return input;
75
+ }
76
+
77
+ template<>
78
+ __device__ __forceinline__ half clamp_inf_for_half(const float input)
79
+ {
80
+ // clamp inf values to enable fp16 training
81
+ return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
82
+ }
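
blockReduceSum above works in two stages: every warp folds its 32 lanes with XOR-butterfly shuffles, lane 0 of each warp parks its partial sum in shared memory, and the first warp then reduces those partials with the same butterfly. The butterfly itself is easy to see when simulated on the host; the sketch below reproduces the __shfl_xor_sync pattern in plain C++ (illustrative only, not CUDA):

    #include <cassert>

    int main()
    {
        // One warp's worth of lane values.
        float lane[32];
        float expected = 0.0f;
        for (int i = 0; i < 32; i++) { lane[i] = (float)(i + 1); expected += lane[i]; }

        // After the step with offset mask, every lane holds the sum of a group of
        // 32/mask lanes; after mask == 1 all 32 lanes hold the full warp sum.
        for (int mask = 16; mask > 0; mask >>= 1)
        {
            float shuffled[32];
            for (int i = 0; i < 32; i++) shuffled[i] = lane[i ^ mask];  // __shfl_xor_sync
            for (int i = 0; i < 32; i++) lane[i] += shuffled[i];
        }

        for (int i = 0; i < 32; i++) assert(lane[i] == expected);
        return 0;
    }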