Training in progress, step 500
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- AutoAWQ_kernels/.github/workflows/build.yaml +232 -0
- AutoAWQ_kernels/.github/workflows/scripts/github_create_release.js +17 -0
- AutoAWQ_kernels/.gitignore +164 -0
- AutoAWQ_kernels/LICENSE +21 -0
- AutoAWQ_kernels/README.md +42 -0
- AutoAWQ_kernels/awq_ext/attention/cuda_bf16_fallbacks.cuh +257 -0
- AutoAWQ_kernels/awq_ext/attention/cuda_bf16_wrapper.h +23 -0
- AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.cu +152 -0
- AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.h +184 -0
- AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_template.hpp +1608 -0
- AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_utils.h +1786 -0
- AutoAWQ_kernels/awq_ext/attention/ft_attention.cpp +182 -0
- AutoAWQ_kernels/awq_ext/attention/ft_attention.h +15 -0
- AutoAWQ_kernels/awq_ext/exllama/cu_compat.cuh +58 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cu +75 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cuh +55 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cu +63 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cuh +19 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cu +260 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cuh +43 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cu +227 -0
- AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cuh +53 -0
- AutoAWQ_kernels/awq_ext/exllama/exllama_ext.cpp +260 -0
- AutoAWQ_kernels/awq_ext/exllama/hip_compat.cuh +51 -0
- AutoAWQ_kernels/awq_ext/exllama/matrix.cuh +294 -0
- AutoAWQ_kernels/awq_ext/exllama/tuning.h +13 -0
- AutoAWQ_kernels/awq_ext/exllama/util.cuh +33 -0
- AutoAWQ_kernels/awq_ext/exllamav2/config.h +13 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cpp/util.h +12 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat.cuh +56 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat_gemm.cuh +38 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/matrix_view.cuh +121 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cu +211 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cuh +33 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel.cuh +487 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel_gptq.cuh +219 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cu +623 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cuh +73 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_2.cuh +103 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_3.cuh +169 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_4.cuh +227 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_5.cuh +207 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_6.cuh +44 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_8.cuh +38 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_util.cuh +51 -0
- AutoAWQ_kernels/awq_ext/exllamav2/cuda/util.cuh +42 -0
- AutoAWQ_kernels/awq_ext/exllamav2/ext.cpp +134 -0
- AutoAWQ_kernels/awq_ext/layernorm/layernorm.cu +113 -0
- AutoAWQ_kernels/awq_ext/layernorm/layernorm.h +3 -0
- AutoAWQ_kernels/awq_ext/layernorm/reduction.cuh +82 -0
AutoAWQ_kernels/.github/workflows/build.yaml
ADDED
@@ -0,0 +1,232 @@
name: Build AutoAWQ Wheels with CUDA

on:
  push:
    tags:
      - "v*"

jobs:
  release:
    # Retrieve tag and create release
    name: Create Release
    runs-on: ubuntu-latest
    outputs:
      upload_url: ${{ steps.create_release.outputs.upload_url }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Extract branch info
        shell: bash
        run: |
          echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV

      - name: Create Release
        id: create_release
        uses: "actions/github-script@v6"
        env:
          RELEASE_TAG: ${{ env.release_tag }}
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
          script: |
            const script = require('.github/workflows/scripts/github_create_release.js')
            await script(github, context, core)

  build_cuda_wheels:
    name: Build AWQ with CUDA
    runs-on: ${{ matrix.os }}
    needs: release

    strategy:
      matrix:
        os: [ubuntu-20.04, windows-latest]
        pyver: ["3.8", "3.9", "3.10", "3.11", "3.12"]
        cuda: ["11.8.0", "12.1.1"]
    defaults:
      run:
        shell: pwsh
    env:
      PYPI_CUDA_VERSION: "12.1.1"
      CUDA_VERSION: ${{ matrix.cuda }}

    steps:
      - name: Free Disk Space
        uses: jlumbroso/[email protected]
        if: runner.os == 'Linux'
        with:
          tool-cache: false
          android: true
          dotnet: true
          haskell: true
          large-packages: false
          docker-images: true
          swap-storage: false

      - uses: actions/checkout@v3

      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.pyver }}

      - name: Setup Mamba
        uses: conda-incubator/[email protected]
        with:
          activate-environment: "build"
          python-version: ${{ matrix.pyver }}
          miniforge-variant: Mambaforge
          miniforge-version: latest
          use-mamba: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      - name: Install Dependencies
        run: |
          # Install CUDA toolkit
          mamba install -y 'cuda' -c "nvidia/label/cuda-${env:CUDA_VERSION}"

          # Env variables
          $env:CUDA_PATH = $env:CONDA_PREFIX
          $env:CUDA_HOME = $env:CONDA_PREFIX

          # Install torch
          $cudaVersion = $env:CUDA_VERSION.Replace('.', '')
          $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1)
          $pytorchVersion = "torch==2.3.1"
          python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch
          python -m pip install build setuptools wheel ninja

          # Print version information
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "import os; print('CUDA_HOME:', os.getenv('CUDA_HOME', None))"
          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

      - name: Build Wheel
        run: |
          $env:CUDA_PATH = $env:CONDA_PREFIX
          $env:CUDA_HOME = $env:CONDA_PREFIX

          # Only add +cu118 to wheel if not releasing on PyPi
          if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){
            $env:PYPI_BUILD = 1
          }

          python setup.py sdist bdist_wheel

      - name: Upload Assets
        uses: shogo82148/actions-upload-release-asset@v1
        with:
          upload_url: ${{ needs.release.outputs.upload_url }}
          asset_path: ./dist/*.whl

  build_rocm_wheels:
    name: Build AWQ with ROCm
    runs-on: ${{ matrix.os }}
    needs: release

    strategy:
      matrix:
        os: [ubuntu-20.04]
        python: ["3.8", "3.9", "3.10", "3.11"]
        rocm: ["5.6.1", "5.7.1"] # we build only for rocm5.6 & 5.7 to match PyTorch 2.1.0 and PyTorch 2.2 nightly
    defaults:
      run:
        shell: bash
    env:
      ROCM_VERSION: ${{ matrix.rocm }}

    steps:
      - uses: actions/checkout@v3

      - name: Free Disk Space
        run: |
          df -h
          echo "Removing large packages"
          sudo apt-get remove -y '^dotnet-.*'
          sudo apt-get remove -y 'php.*'
          sudo apt-get remove -y azure-cli google-chrome-stable firefox powershell mono-devel
          df -h
          sudo apt-get autoremove -y >/dev/null 2>&1
          sudo apt-get clean
          sudo apt-get autoremove -y >/dev/null 2>&1
          sudo apt-get autoclean -y >/dev/null 2>&1
          df -h
          echo "https://github.com/actions/virtual-environments/issues/709"
          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
          df -h
          echo "remove big /usr/local"
          sudo rm -rf "/usr/local/share/boost"
          sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
          df -h
          sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
          sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
          sudo rm -rf /usr/share/swift > /dev/null 2>&1
          df -h

      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python }}

      - name: Setup Mamba
        uses: conda-incubator/[email protected]
        with:
          activate-environment: "build"
          python-version: ${{ matrix.python }}
          mamba-version: "*"
          use-mamba: false
          channels: conda-forge,defaults
          channel-priority: true
          add-pip-as-python-dependency: true
          auto-activate-base: false

      - name: Set up ROCm
        run: |
          echo "Using python:"
          python --version
          which python

          if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
          elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.6.50601-1_all.deb
          elif [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
            export ROCM_DL_FILE=amdgpu-install_5.7.50701-1_all.deb
          else
            echo Unknown rocm version
            exit 1
          fi

          curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
          sudo dpkg -i $ROCM_DL_FILE
          sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y

      - name: Install Dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev

          python -m pip install --upgrade build setuptools wheel

          if [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
            python -m pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/rocm5.7
          elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
            python -m pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/rocm5.6
          else
            echo Unknown rocm version for python install
            exit 1
          fi

      - name: Build Wheel
        run: |
          echo "Using python for build:"
          python --version
          which python

          ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel

      - name: Upload Assets
        uses: shogo82148/actions-upload-release-asset@v1
        with:
          upload_url: ${{ needs.release.outputs.upload_url }}
          asset_path: ./dist/*.whl
AutoAWQ_kernels/.github/workflows/scripts/github_create_release.js
ADDED
@@ -0,0 +1,17 @@
module.exports = async (github, context, core) => {
  try {
    const response = await github.rest.repos.createRelease({
      draft: false,
      generate_release_notes: true,
      name: process.env.RELEASE_TAG,
      owner: context.repo.owner,
      prerelease: false,
      repo: context.repo.repo,
      tag_name: process.env.RELEASE_TAG,
    });

    core.setOutput('upload_url', response.data.upload_url);
  } catch (error) {
    core.setFailed(error.message);
  }
}
AutoAWQ_kernels/.gitignore
ADDED
@@ -0,0 +1,164 @@
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

*hip*
!hip_compact.hip
AutoAWQ_kernels/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Casper

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
AutoAWQ_kernels/README.md
ADDED
@@ -0,0 +1,42 @@
# AutoAWQ Kernels

AutoAWQ Kernels is a new package that is split up from the [main repository](https://github.com/casper-hansen/AutoAWQ) in order to avoid compilation times.

## Requirements

- Windows: Must use WSL2.

- NVIDIA:
  - GPU: Must be compute capability 7.5 or higher.
  - CUDA Toolkit: Must be 11.8 or higher.
- AMD:
  - ROCm: Must be 5.6 or higher.

## Install

### Install from PyPi

The package is available on PyPi with CUDA 12.1.1 wheels:

```
pip install autoawq-kernels
```

### Install release wheels

For ROCm and other CUDA versions, you can use the wheels published at each [release](https://github.com/casper-hansen/AutoAWQ_kernels/releases/):

```
pip install https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.2/autoawq_kernels-0.0.2+rocm561-cp310-cp310-linux_x86_64.whl
```

### Build from source
You can also build from source:

```
git clone https://github.com/casper-hansen/AutoAWQ_kernels
cd AutoAWQ_kernels
pip install -e .
```

To build for ROCm, you need to first install the following packages `rocsparse-dev hipsparse-dev rocthrust-dev rocblas-dev hipblas-dev`.
AutoAWQ_kernels/awq_ext/attention/cuda_bf16_fallbacks.cuh
ADDED
@@ -0,0 +1,257 @@
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>

namespace fastertransformer {

#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = __low2float(val);
    f_val.y = __high2float(val);
    return f_val;
#else
    return __bfloat1622float2(val);
#endif
}

inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    union { int8_t int8[2]; int16_t int16; };
    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union { int8_t int8[2]; int16_t int16; };
    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
    return int16;
#endif
}

inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __floats2bfloat162_rn(val.x, val.y);
#else
    return __float22bfloat162_rn(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 val2;
    val2.x = val;
    val2.y = val;
    return val2;
#else
    return __bfloat162bfloat162(val);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
    return __hadd2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
#else
    return __hadd(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
    return __hsub2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
#else
    return __hsub(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
    return __hmul2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
#else
    return __hmul(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh, fzl, fzh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    fzl = __low2float(z);
    fzh = __high2float(z);
    return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
    return __hfma2(x, y, z);
#endif
}

inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
    return __hfma(x, y, z);
#endif
}

inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh;
    fxl = __low2float(x);
    fxh = __high2float(x);;
    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
    return h2exp(x);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
    __nv_bfloat162 t; t.x = x; t.y = y; return t;
}

#endif

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
    return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    fdl = __low2float(d);
    fdh = __high2float(d);
    return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
    return a * b * c + d;
#endif
}

#endif // ENABLE_BF16

} // namespace fastertransformer
AutoAWQ_kernels/awq_ext/attention/cuda_bf16_wrapper.h
ADDED
@@ -0,0 +1,23 @@
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.cu
ADDED
@@ -0,0 +1,152 @@
// Adapted from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

#include "decoder_masked_multihead_attention_template.hpp"

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
    auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
                                                          THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
    dim3 grid(params.num_heads, params.batch_size); \
    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
    if (tlength < 32) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
    }
    else if (tlength < 2048) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
    }
    else {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#undef MMHA_LAUNCH_KERNEL

template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    switch (params.hidden_size_per_head) {
        case 32:
            mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 48:
            mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 64:
            mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 80:
            mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 96:
            mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 112:
            mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 128:
            mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 160:
            mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 192:
            mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 224:
            mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 256:
            mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        default:
            assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention.h
ADDED
@@ -0,0 +1,184 @@
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CHECK_CUDA(call) \
    do { \
        cudaError_t status_ = call; \
        if (status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1); \
        } \
    } while (0)

////////////////////////////////////////////////////////////////////////////////////////////////////

// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.

template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;
    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;
    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;
    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;
    // The indirections to use for cache when beam sampling.
    const int* cache_indir = nullptr;

    // Stride to handle the case when KQV is a single buffer
    int stride = 0;

    // The batch size.
    int batch_size = 0;
    // The beam width
    int beam_width = 0;
    // The sequence length.
    int memory_max_len = 0;
    // The number of heads (H).
    int num_heads = 0;
    // The number of heads for KV cache.
    int num_kv_heads = 0;
    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;
    // The per-head latent space reserved for rotary embeddings.
    int rotary_embedding_dim = 0;
    bool neox_rotary_style = false;
    float rotary_base = 0.0f;
    // The maximum length of input sentences.
    int max_input_length = 0;
    // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
    int timestep = 0;
    // The current timestep of each sentences (support different timestep for different sentences)

    // The 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Used when we have some input context like gpt
    const int* total_padding_tokens = nullptr;

    const bool* masked_tokens = nullptr;
    const int* prefix_prompt_lengths = nullptr;
    int max_prefix_prompt_length = 0;

    const T* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;
    // The slope per head of linear position bias to attention score (H).
    const float* linear_bias_slopes = nullptr;

    const T* ia3_key_weights = nullptr;
    const T* ia3_value_weights = nullptr;
    const int* ia3_tasks = nullptr;

    const float* qkv_scale_out = nullptr;
    const float* attention_out_scale = nullptr;
    int int8_mode = 0;
};

template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // allows to exist attention eary
    bool* finished = nullptr;

    // required in case of cross attention
    // will need it here till if constexpr in c++17
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // allows to exist attention eary
    bool* finished = nullptr;

    // required in case of cross attention
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;

template<typename T>
struct outputCrossAttentionParam {
    // max decoder output length
    int max_decoder_seq_len = 0;
    T* cross_attention_out = nullptr;
    bool is_return_cross_attentions = false;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_template.hpp
ADDED
@@ -0,0 +1,1608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Downloaded from FasterTransformer v5.2.1
|
2 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
3 |
+
/*
|
4 |
+
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
*
|
6 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
* you may not use this file except in compliance with the License.
|
8 |
+
* You may obtain a copy of the License at
|
9 |
+
*
|
10 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
*
|
12 |
+
* Unless required by applicable law or agreed to in writing, software
|
13 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
* See the License for the specific language governing permissions and
|
16 |
+
* limitations under the License.
|
17 |
+
*/
|
18 |
+
#pragma once
|
19 |
+
|
20 |
+
#include "decoder_masked_multihead_attention.h"
|
21 |
+
#include "decoder_masked_multihead_attention_utils.h"
|
22 |
+
#include "cuda_bf16_wrapper.h"
|
23 |
+
#include "cuda_bf16_fallbacks.cuh"
|
24 |
+
#include <assert.h>
|
25 |
+
#include <float.h>
|
26 |
+
#include <type_traits>
|
27 |
+
|
28 |
+
// #define MMHA_USE_HMMA_FOR_REDUCTION
|
29 |
+
|
30 |
+
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
|
31 |
+
|
32 |
+
// Does not seem to affect the accuracy that much
|
33 |
+
#define MMHA_USE_FP32_ACUM_FOR_FMA
|
34 |
+
|
35 |
+
// Seems to slightly improve the accuracy
|
36 |
+
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
37 |
+
|
38 |
+
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
|
39 |
+
// Does not seem to improve the accuracy
|
40 |
+
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
|
41 |
+
#endif
|
42 |
+
|
43 |
+
namespace mmha {
|
44 |
+
|
45 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
46 |
+
|
47 |
+
//
|
48 |
+
// We use the following terminology to describe the different dimensions.
|
49 |
+
//
|
50 |
+
// B: Batch size (number of sequences),
|
51 |
+
// L: Sequence length,
|
52 |
+
// D: Hidden dimension,
|
53 |
+
// H: Number of heads,
|
54 |
+
// Dh: Hidden dimension per head - Dh = D / H.
|
55 |
+
//
|
56 |
+
// The different kernels assign a threadblock to each B x H pair. The grid has size (1, B, H). We use
|
57 |
+
// 64, 128 and 256 threads per block.
|
58 |
+
//
|
59 |
+
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
|
60 |
+
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
|
61 |
+
// cache buffer helps with memory accesses and contains keys with bias.
|
62 |
+
//
|
63 |
+
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
|
64 |
+
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
|
65 |
+
// values for x are chosen to create chunks of 16 bytes.
|
66 |
+
//
|
67 |
+
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
|
68 |
+
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
|
69 |
+
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
|
70 |
+
// HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
|
71 |
+
//
|
72 |
+
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
|
73 |
+
// shared memory.
|
74 |
+
//
|
75 |
+
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
|
76 |
+
// timesteps are computed per loop iteration. As with the keys, the values are read from a cache
|
77 |
+
// except for the current timestep. The layout of the cache buffer for the values is much simpler
|
78 |
+
// as it is [B, H, L, Dh].
|
79 |
+
//
|
80 |
+
|
81 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
82 |
+
|
83 |
+
template<typename T, int Dh>
|
84 |
+
struct Qk_vec_ {
|
85 |
+
};
|
86 |
+
|
87 |
+
template<>
|
88 |
+
struct Qk_vec_<float, 32> {
|
89 |
+
using Type = float;
|
90 |
+
};
|
91 |
+
template<>
|
92 |
+
struct Qk_vec_<float, 64> {
|
93 |
+
using Type = float2;
|
94 |
+
};
|
95 |
+
template<>
|
96 |
+
struct Qk_vec_<float, 128> {
|
97 |
+
using Type = float4;
|
98 |
+
};
|
99 |
+
template<>
|
100 |
+
struct Qk_vec_<float, 256> {
|
101 |
+
using Type = float4;
|
102 |
+
};
|
103 |
+
template<>
|
104 |
+
struct Qk_vec_<uint16_t, 32> {
|
105 |
+
using Type = uint32_t;
|
106 |
+
};
|
107 |
+
template<>
|
108 |
+
struct Qk_vec_<uint16_t, 64> {
|
109 |
+
using Type = uint32_t;
|
110 |
+
};
|
111 |
+
template<>
|
112 |
+
struct Qk_vec_<uint16_t, 128> {
|
113 |
+
using Type = uint2;
|
114 |
+
};
|
115 |
+
template<>
|
116 |
+
struct Qk_vec_<uint16_t, 256> {
|
117 |
+
using Type = uint4;
|
118 |
+
};
|
119 |
+
#ifdef ENABLE_BF16
|
120 |
+
template<>
|
121 |
+
struct Qk_vec_<__nv_bfloat16, 32> {
|
122 |
+
using Type = __nv_bfloat162;
|
123 |
+
};
|
124 |
+
template<>
|
125 |
+
struct Qk_vec_<__nv_bfloat16, 64> {
|
126 |
+
using Type = __nv_bfloat162;
|
127 |
+
};
|
128 |
+
template<>
|
129 |
+
struct Qk_vec_<__nv_bfloat16, 128> {
|
130 |
+
using Type = bf16_4_t;
|
131 |
+
};
|
132 |
+
template<>
|
133 |
+
struct Qk_vec_<__nv_bfloat16, 256> {
|
134 |
+
using Type = bf16_8_t;
|
135 |
+
};
|
136 |
+
#endif // ENABLE_BF16
|
137 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
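// A usage sketch (not from the FasterTransformer source): Qk_vec_ maps a scalar type and a
// head size to a load/store vector of at most 16 bytes, so a small group of threads covers
// the whole head. The asserts below merely restate two rows of the table above.
static_assert(sizeof(Qk_vec_<float, 128>::Type) == 16, "float, Dh = 128 -> float4 (16B per load)");
static_assert(sizeof(Qk_vec_<uint16_t, 256>::Type) == 16, "half, Dh = 256 -> uint4 (8 halves, 16B per load)");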
138 |
+
|
139 |
+
template<typename T, int THREADS_PER_KEY>
|
140 |
+
struct K_vec_ {
|
141 |
+
};
|
142 |
+
|
143 |
+
template<>
|
144 |
+
struct K_vec_<float, 4> {
|
145 |
+
using Type = float;
|
146 |
+
};
|
147 |
+
template<>
|
148 |
+
struct K_vec_<float, 2> {
|
149 |
+
using Type = float2;
|
150 |
+
};
|
151 |
+
template<>
|
152 |
+
struct K_vec_<float, 1> {
|
153 |
+
using Type = float4;
|
154 |
+
};
|
155 |
+
template<>
|
156 |
+
struct K_vec_<uint16_t, 4> {
|
157 |
+
using Type = uint32_t;
|
158 |
+
};
|
159 |
+
template<>
|
160 |
+
struct K_vec_<uint16_t, 2> {
|
161 |
+
using Type = uint2;
|
162 |
+
};
|
163 |
+
template<>
|
164 |
+
struct K_vec_<uint16_t, 1> {
|
165 |
+
using Type = uint4;
|
166 |
+
};
|
167 |
+
#ifdef ENABLE_BF16
|
168 |
+
template<>
|
169 |
+
struct K_vec_<__nv_bfloat16, 4> {
|
170 |
+
using Type = __nv_bfloat162;
|
171 |
+
};
|
172 |
+
template<>
|
173 |
+
struct K_vec_<__nv_bfloat16, 2> {
|
174 |
+
using Type = bf16_4_t;
|
175 |
+
};
|
176 |
+
template<>
|
177 |
+
struct K_vec_<__nv_bfloat16, 1> {
|
178 |
+
using Type = bf16_8_t;
|
179 |
+
};
|
180 |
+
#endif // ENABLE_BF16
|
181 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
182 |
+
|
183 |
+
template<typename T, int V_VEC_SIZE>
|
184 |
+
struct V_vec_ {
|
185 |
+
};
|
186 |
+
|
187 |
+
template<>
|
188 |
+
struct V_vec_<float, 1> {
|
189 |
+
using Type = float;
|
190 |
+
};
|
191 |
+
template<>
|
192 |
+
struct V_vec_<float, 2> {
|
193 |
+
using Type = float2;
|
194 |
+
};
|
195 |
+
template<>
|
196 |
+
struct V_vec_<float, 4> {
|
197 |
+
using Type = float4;
|
198 |
+
};
|
199 |
+
template<>
|
200 |
+
struct V_vec_<uint16_t, 2> {
|
201 |
+
using Type = uint32_t;
|
202 |
+
};
|
203 |
+
template<>
|
204 |
+
struct V_vec_<uint16_t, 4> {
|
205 |
+
using Type = uint2;
|
206 |
+
};
|
207 |
+
template<>
|
208 |
+
struct V_vec_<uint16_t, 8> {
|
209 |
+
using Type = uint4;
|
210 |
+
};
|
211 |
+
#ifdef ENABLE_BF16
|
212 |
+
template<>
|
213 |
+
struct V_vec_<__nv_bfloat16, 2> {
|
214 |
+
using Type = __nv_bfloat162;
|
215 |
+
};
|
216 |
+
template<>
|
217 |
+
struct V_vec_<__nv_bfloat16, 4> {
|
218 |
+
using Type = bf16_4_t;
|
219 |
+
};
|
220 |
+
template<>
|
221 |
+
struct V_vec_<__nv_bfloat16, 8> {
|
222 |
+
using Type = bf16_8_t;
|
223 |
+
};
|
224 |
+
#endif // ENABLE_BF16
|
225 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
226 |
+
|
227 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
228 |
+
template<typename T>
|
229 |
+
struct Qk_vec_acum_fp32_ {
|
230 |
+
};
|
231 |
+
|
232 |
+
template<>
|
233 |
+
struct Qk_vec_acum_fp32_<float> {
|
234 |
+
using Type = float;
|
235 |
+
};
|
236 |
+
template<>
|
237 |
+
struct Qk_vec_acum_fp32_<float2> {
|
238 |
+
using Type = float2;
|
239 |
+
};
|
240 |
+
template<>
|
241 |
+
struct Qk_vec_acum_fp32_<float4> {
|
242 |
+
using Type = float4;
|
243 |
+
};
|
244 |
+
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
|
245 |
+
template<>
|
246 |
+
struct Qk_vec_acum_fp32_<uint32_t> {
|
247 |
+
using Type = float2;
|
248 |
+
};
|
249 |
+
template<>
|
250 |
+
struct Qk_vec_acum_fp32_<uint2> {
|
251 |
+
using Type = Float4_;
|
252 |
+
};
|
253 |
+
template<>
|
254 |
+
struct Qk_vec_acum_fp32_<uint4> {
|
255 |
+
using Type = Float8_;
|
256 |
+
};
|
257 |
+
template<>
|
258 |
+
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
|
259 |
+
using Type = float;
|
260 |
+
};
|
261 |
+
template<>
|
262 |
+
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
|
263 |
+
using Type = float2;
|
264 |
+
};
|
265 |
+
template<>
|
266 |
+
struct Qk_vec_acum_fp32_<bf16_4_t> {
|
267 |
+
using Type = Float4_;
|
268 |
+
};
|
269 |
+
template<>
|
270 |
+
struct Qk_vec_acum_fp32_<bf16_8_t> {
|
271 |
+
using Type = Float8_;
|
272 |
+
};
|
273 |
+
|
274 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
275 |
+
|
276 |
+
template<typename T>
|
277 |
+
struct K_vec_acum_fp32_ {
|
278 |
+
};
|
279 |
+
|
280 |
+
template<>
|
281 |
+
struct K_vec_acum_fp32_<float> {
|
282 |
+
using Type = float;
|
283 |
+
};
|
284 |
+
template<>
|
285 |
+
struct K_vec_acum_fp32_<float2> {
|
286 |
+
using Type = float2;
|
287 |
+
};
|
288 |
+
template<>
|
289 |
+
struct K_vec_acum_fp32_<float4> {
|
290 |
+
using Type = float4;
|
291 |
+
};
|
292 |
+
template<>
|
293 |
+
struct K_vec_acum_fp32_<uint32_t> {
|
294 |
+
using Type = float2;
|
295 |
+
};
|
296 |
+
template<>
|
297 |
+
struct K_vec_acum_fp32_<uint2> {
|
298 |
+
using Type = Float4_;
|
299 |
+
};
|
300 |
+
template<>
|
301 |
+
struct K_vec_acum_fp32_<uint4> {
|
302 |
+
using Type = Float8_;
|
303 |
+
};
|
304 |
+
template<>
|
305 |
+
struct K_vec_acum_fp32_<__nv_bfloat16> {
|
306 |
+
using Type = float;
|
307 |
+
};
|
308 |
+
template<>
|
309 |
+
struct K_vec_acum_fp32_<__nv_bfloat162> {
|
310 |
+
using Type = float2;
|
311 |
+
};
|
312 |
+
template<>
|
313 |
+
struct K_vec_acum_fp32_<bf16_4_t> {
|
314 |
+
using Type = Float4_;
|
315 |
+
};
|
316 |
+
template<>
|
317 |
+
struct K_vec_acum_fp32_<bf16_8_t> {
|
318 |
+
using Type = Float8_;
|
319 |
+
};
|
320 |
+
#endif
|
321 |
+
|
322 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
323 |
+
|
324 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
|
325 |
+
template<typename T>
|
326 |
+
struct V_vec_acum_fp32_ {
|
327 |
+
};
|
328 |
+
|
329 |
+
template<>
|
330 |
+
struct V_vec_acum_fp32_<float> {
|
331 |
+
using Type = float;
|
332 |
+
};
|
333 |
+
template<>
|
334 |
+
struct V_vec_acum_fp32_<float2> {
|
335 |
+
using Type = float2;
|
336 |
+
};
|
337 |
+
template<>
|
338 |
+
struct V_vec_acum_fp32_<float4> {
|
339 |
+
using Type = float4;
|
340 |
+
};
|
341 |
+
template<>
|
342 |
+
struct V_vec_acum_fp32_<uint32_t> {
|
343 |
+
using Type = float2;
|
344 |
+
};
|
345 |
+
template<>
|
346 |
+
struct V_vec_acum_fp32_<uint2> {
|
347 |
+
using Type = Float4_;
|
348 |
+
};
|
349 |
+
template<>
|
350 |
+
struct V_vec_acum_fp32_<uint4> {
|
351 |
+
using Type = Float8_;
|
352 |
+
};
|
353 |
+
#ifdef ENABLE_BF16
|
354 |
+
template<>
|
355 |
+
struct V_vec_acum_fp32_<__nv_bfloat162> {
|
356 |
+
using Type = float2;
|
357 |
+
};
|
358 |
+
template<>
|
359 |
+
struct V_vec_acum_fp32_<bf16_4_t> {
|
360 |
+
using Type = Float4_;
|
361 |
+
};
|
362 |
+
template<>
|
363 |
+
struct V_vec_acum_fp32_<bf16_8_t> {
|
364 |
+
using Type = Float8_;
|
365 |
+
};
|
366 |
+
#endif // ENABLE_BF16
|
367 |
+
#endif
|
368 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
369 |
+
|
370 |
+
template<int THREADS_PER_KEY, typename K_vec, int N>
|
371 |
+
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
|
372 |
+
{
|
373 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
374 |
+
using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
|
375 |
+
#else
|
376 |
+
using K_vec_acum = K_vec;
|
377 |
+
#endif
|
378 |
+
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
379 |
+
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
|
380 |
+
#pragma unroll
|
381 |
+
for (int ii = 1; ii < N; ++ii) {
|
382 |
+
qk_vec = fma(q[ii], k[ii], qk_vec);
|
383 |
+
}
|
384 |
+
|
385 |
+
// Finalize the reduction across lanes.
|
386 |
+
float qk = sum(qk_vec);
|
387 |
+
#pragma unroll
|
388 |
+
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
|
389 |
+
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
390 |
+
}
|
391 |
+
return qk;
|
392 |
+
}
|
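// An illustrative standalone version (not from the FasterTransformer source) of the
// XOR-butterfly used above to finish the reduction: with THREADS_PER_GROUP == 4, the lanes
// of a key group exchange partial sums with masks 2 and then 1, after which every lane in
// the group holds the full Q*K^T dot product.
template<int THREADS_PER_GROUP>
inline __device__ float group_sum_sketch(float partial)
{
#pragma unroll
    for (int mask = THREADS_PER_GROUP / 2; mask >= 1; mask /= 2) {
        partial += __shfl_xor_sync(uint32_t(-1), partial, mask);
    }
    return partial;
}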
393 |
+
|
394 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
395 |
+
|
396 |
+
template<typename T, int THREADS_PER_KEY>
|
397 |
+
struct Qk_dot {
|
398 |
+
template<typename K_vec, int N>
|
399 |
+
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
|
400 |
+
{
|
401 |
+
return qk_dot_<THREADS_PER_KEY>(q, k);
|
402 |
+
}
|
403 |
+
};
|
404 |
+
|
405 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
406 |
+
|
407 |
+
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
|
408 |
+
{
|
409 |
+
float4 c;
|
410 |
+
float zero = 0.f;
|
411 |
+
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
|
412 |
+
" {%0, %1, %2, %3}, \n"
|
413 |
+
" {%4, %5}, \n"
|
414 |
+
" {%6}, \n"
|
415 |
+
" {%7, %7, %7, %7}; \n"
|
416 |
+
|
417 |
+
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
|
418 |
+
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
|
419 |
+
return c;
|
420 |
+
}
|
421 |
+
|
422 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
423 |
+
|
424 |
+
template<int N>
|
425 |
+
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
426 |
+
{
|
427 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
428 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
429 |
+
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
|
430 |
+
#else
|
431 |
+
using K_vec_acum = uint32_t;
|
432 |
+
#endif
|
433 |
+
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
|
434 |
+
#pragma unroll
|
435 |
+
for (int ii = 1; ii < N; ++ii) {
|
436 |
+
qk_vec = fma(q[ii], k[ii], qk_vec);
|
437 |
+
}
|
438 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
439 |
+
uint32_t qk_vec_ = float2_to_half2(qk_vec);
|
440 |
+
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
|
441 |
+
#else
|
442 |
+
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
|
443 |
+
#endif
|
444 |
+
#else
|
445 |
+
return 0.f;
|
446 |
+
#endif
|
447 |
+
}
|
448 |
+
|
449 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
450 |
+
|
451 |
+
template<>
|
452 |
+
struct Qk_dot<uint16_t, 4> {
|
453 |
+
template<int N>
|
454 |
+
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
455 |
+
{
|
456 |
+
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
|
457 |
+
return qk_hmma_dot_(q, k);
|
458 |
+
#else
|
459 |
+
return qk_dot_<4>(q, k);
|
460 |
+
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
|
461 |
+
}
|
462 |
+
};
|
463 |
+
|
464 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
465 |
+
|
466 |
+
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
|
467 |
+
inline __device__ float block_sum(float* red_smem, float sum)
|
468 |
+
{
|
469 |
+
|
470 |
+
// Decompose the thread index into warp / lane.
|
471 |
+
int warp = threadIdx.x / WARP_SIZE;
|
472 |
+
int lane = threadIdx.x % WARP_SIZE;
|
473 |
+
|
474 |
+
// Compute the sum per warp.
|
475 |
+
#pragma unroll
|
476 |
+
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
477 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
478 |
+
}
|
479 |
+
|
480 |
+
// Warp leaders store the data to shared memory.
|
481 |
+
if (lane == 0) {
|
482 |
+
red_smem[warp] = sum;
|
483 |
+
}
|
484 |
+
|
485 |
+
// Make sure the data is in shared memory.
|
486 |
+
__syncthreads();
|
487 |
+
|
488 |
+
// The warps compute the final sums.
|
489 |
+
if (lane < WARPS_PER_BLOCK) {
|
490 |
+
sum = red_smem[lane];
|
491 |
+
}
|
492 |
+
|
493 |
+
// Parallel reduction inside the warp.
|
494 |
+
#pragma unroll
|
495 |
+
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
|
496 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
497 |
+
}
|
498 |
+
|
499 |
+
// Broadcast to other threads.
|
500 |
+
return __shfl_sync(uint32_t(-1), sum, 0);
|
501 |
+
}
|
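// A usage sketch (not from the FasterTransformer source): each thread passes its partial
// sum, red_smem provides one float slot per warp, and every thread of the block receives
// the block-wide total. Note that block_sum assumes WARPS_PER_BLOCK <= WARP_SIZE so the
// second (cross-warp) butterfly fits inside a single warp.
template<int WARPS_PER_BLOCK>
inline __device__ float block_sum_usage_sketch(float partial)
{
    __shared__ float red_smem[WARPS_PER_BLOCK];
    return block_sum<WARPS_PER_BLOCK>(red_smem, partial);
}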
502 |
+
|
503 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
504 |
+
|
505 |
+
inline __device__ void convert_from_float(float& dst, float src)
|
506 |
+
{
|
507 |
+
dst = src;
|
508 |
+
}
|
509 |
+
|
510 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
511 |
+
|
512 |
+
inline __device__ void convert_from_float(uint16_t& dst, float src)
|
513 |
+
{
|
514 |
+
dst = float_to_half(src);
|
515 |
+
}
|
516 |
+
|
517 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
518 |
+
|
519 |
+
inline __device__ void convert_from_float(uint32_t& dst, float2 src)
|
520 |
+
{
|
521 |
+
dst = float2_to_half2(src);
|
522 |
+
}
|
523 |
+
|
524 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
525 |
+
#ifdef ENABLE_BF16
|
526 |
+
inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
|
527 |
+
{
|
528 |
+
dst = __float2bfloat16(src);
|
529 |
+
}
|
530 |
+
|
531 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
532 |
+
|
533 |
+
inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
|
534 |
+
{
|
535 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
536 |
+
dst = __float22bfloat162_rn(src);
|
537 |
+
#else
|
538 |
+
dst = __floats2bfloat162_rn(src.x, src.y);
|
539 |
+
#endif
|
540 |
+
}
|
541 |
+
#endif // ENABLE_BF16
|
542 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
543 |
+
|
544 |
+
inline __device__ void convert_from_float(uint2& dst, Float4_ src)
|
545 |
+
{
|
546 |
+
dst.x = float2_to_half2(src.x);
|
547 |
+
dst.y = float2_to_half2(src.y);
|
548 |
+
}
|
549 |
+
|
550 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
551 |
+
|
552 |
+
inline __device__ void convert_from_float(uint2& dst, float4 src)
|
553 |
+
{
|
554 |
+
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
|
555 |
+
}
|
556 |
+
|
557 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
558 |
+
|
559 |
+
inline __device__ void convert_from_float(uint4& dst, Float8_ src)
|
560 |
+
{
|
561 |
+
dst.x = float2_to_half2(src.x);
|
562 |
+
dst.y = float2_to_half2(src.y);
|
563 |
+
dst.z = float2_to_half2(src.z);
|
564 |
+
dst.w = float2_to_half2(src.w);
|
565 |
+
}
|
566 |
+
|
567 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
568 |
+
|
569 |
+
#ifdef ENABLE_BF16
|
570 |
+
inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
|
571 |
+
{
|
572 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
573 |
+
dst.x = __float22bfloat162_rn(src.x);
|
574 |
+
dst.y = __float22bfloat162_rn(src.y);
|
575 |
+
#else
|
576 |
+
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
|
577 |
+
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
|
578 |
+
#endif
|
579 |
+
}
|
580 |
+
|
581 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
582 |
+
|
583 |
+
inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
|
584 |
+
{
|
585 |
+
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
|
586 |
+
}
|
587 |
+
|
588 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
589 |
+
|
590 |
+
inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
|
591 |
+
{
|
592 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
593 |
+
dst.x = __float22bfloat162_rn(src.x);
|
594 |
+
dst.y = __float22bfloat162_rn(src.y);
|
595 |
+
dst.z = __float22bfloat162_rn(src.z);
|
596 |
+
dst.w = __float22bfloat162_rn(src.w);
|
597 |
+
#else
|
598 |
+
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
|
599 |
+
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
|
600 |
+
dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
|
601 |
+
dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
|
602 |
+
#endif
|
603 |
+
}
|
604 |
+
#endif // ENABLE_BF16
|
605 |
+
|
606 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
607 |
+
|
608 |
+
inline __device__ void convert_from_float(float2& dst, float2 src)
|
609 |
+
{
|
610 |
+
dst = src;
|
611 |
+
}
|
612 |
+
|
613 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
614 |
+
|
615 |
+
inline __device__ void convert_from_float(float4& dst, float4 src)
|
616 |
+
{
|
617 |
+
dst = src;
|
618 |
+
}
|
619 |
+
|
620 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
621 |
+
|
622 |
+
inline __device__ float convert_to_float(float4 u)
|
623 |
+
{
|
624 |
+
return u.x;
|
625 |
+
}
|
626 |
+
|
627 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
628 |
+
|
629 |
+
inline __device__ float convert_to_float(uint4 u)
|
630 |
+
{
|
631 |
+
float2 tmp = half2_to_float2(u.x);
|
632 |
+
return tmp.x;
|
633 |
+
}
|
634 |
+
|
635 |
+
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
|
636 |
+
|
637 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
638 |
+
|
639 |
+
inline __device__ float cast_to_float(float u)
|
640 |
+
{
|
641 |
+
return u;
|
642 |
+
}
|
643 |
+
|
644 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
645 |
+
|
646 |
+
inline __device__ float2 cast_to_float(float2 u)
|
647 |
+
{
|
648 |
+
return u;
|
649 |
+
}
|
650 |
+
|
651 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
652 |
+
|
653 |
+
inline __device__ float4 cast_to_float(float4 u)
|
654 |
+
{
|
655 |
+
return u;
|
656 |
+
}
|
657 |
+
|
658 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
659 |
+
|
660 |
+
inline __device__ Float4_ cast_to_float(Float4_ u)
|
661 |
+
{
|
662 |
+
return u;
|
663 |
+
}
|
664 |
+
|
665 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
666 |
+
|
667 |
+
inline __device__ Float8_ cast_to_float(Float8_ u)
|
668 |
+
{
|
669 |
+
return u;
|
670 |
+
}
|
671 |
+
|
672 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
673 |
+
|
674 |
+
inline __device__ float2 cast_to_float(uint32_t u)
|
675 |
+
{
|
676 |
+
return half2_to_float2(u);
|
677 |
+
}
|
678 |
+
|
679 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
680 |
+
|
681 |
+
inline __device__ Float4_ cast_to_float(uint2 u)
|
682 |
+
{
|
683 |
+
Float4_ tmp;
|
684 |
+
tmp.x = half2_to_float2(u.x);
|
685 |
+
tmp.y = half2_to_float2(u.y);
|
686 |
+
return tmp;
|
687 |
+
}
|
688 |
+
|
689 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
690 |
+
|
691 |
+
inline __device__ Float8_ cast_to_float(uint4 u)
|
692 |
+
{
|
693 |
+
Float8_ tmp;
|
694 |
+
tmp.x = half2_to_float2(u.x);
|
695 |
+
tmp.y = half2_to_float2(u.y);
|
696 |
+
tmp.z = half2_to_float2(u.z);
|
697 |
+
tmp.w = half2_to_float2(u.w);
|
698 |
+
return tmp;
|
699 |
+
}
|
700 |
+
|
701 |
+
#endif
|
702 |
+
|
703 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
704 |
+
|
705 |
+
inline __device__ float float_from_int8(int8_t u)
|
706 |
+
{
|
707 |
+
return u;
|
708 |
+
}
|
709 |
+
|
710 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
711 |
+
|
712 |
+
inline __device__ float2 float_from_int8(int16_t u)
|
713 |
+
{
|
714 |
+
union {
|
715 |
+
int16_t int16;
|
716 |
+
int8_t int8[2];
|
717 |
+
};
|
718 |
+
int16 = u;
|
719 |
+
return make_float2(int8[0], int8[1]);
|
720 |
+
}
|
721 |
+
|
722 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
723 |
+
|
724 |
+
inline __device__ float4 float_from_int8(int32_t u)
|
725 |
+
{
|
726 |
+
union {
|
727 |
+
int32_t int32;
|
728 |
+
int8_t int8[4];
|
729 |
+
};
|
730 |
+
int32 = u;
|
731 |
+
return make_float4(int8[0], int8[1], int8[2], int8[3]);
|
732 |
+
}
|
733 |
+
|
734 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
735 |
+
|
736 |
+
// clang-format off
|
737 |
+
inline __device__ Float8_ float_from_int8(int64_t u)
|
738 |
+
{
|
739 |
+
union {
|
740 |
+
int64_t int64;
|
741 |
+
int16_t int16[4];
|
742 |
+
};
|
743 |
+
int64 = u;
|
744 |
+
return Float8_ {float_from_int8(int16[0]),
|
745 |
+
float_from_int8(int16[1]),
|
746 |
+
float_from_int8(int16[2]),
|
747 |
+
float_from_int8(int16[3])};
|
748 |
+
}
|
749 |
+
// clang-format on
|
750 |
+
|
751 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
752 |
+
|
753 |
+
inline __device__ int8_t cast_to_int8(float val)
|
754 |
+
{
|
755 |
+
union {
|
756 |
+
int8_t int8[2];
|
757 |
+
int16_t int16;
|
758 |
+
};
|
759 |
+
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
|
760 |
+
return int8[0];
|
761 |
+
}
|
762 |
+
|
763 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
764 |
+
|
765 |
+
inline __device__ int32_t cast_to_int8(float4 val)
|
766 |
+
{
|
767 |
+
union {
|
768 |
+
int8_t int8[4];
|
769 |
+
int32_t int32;
|
770 |
+
};
|
771 |
+
int8[0] = cast_to_int8(val.x);
|
772 |
+
int8[1] = cast_to_int8(val.y);
|
773 |
+
int8[2] = cast_to_int8(val.z);
|
774 |
+
int8[3] = cast_to_int8(val.w);
|
775 |
+
return int32;
|
776 |
+
}
|
777 |
+
|
778 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
779 |
+
|
780 |
+
inline __device__ int64_t cast_to_int8(Float8_ val)
|
781 |
+
{
|
782 |
+
union {
|
783 |
+
int8_t int8[8];
|
784 |
+
int64_t int64;
|
785 |
+
};
|
786 |
+
int8[0] = cast_to_int8(val.x.x);
|
787 |
+
int8[1] = cast_to_int8(val.x.y);
|
788 |
+
int8[2] = cast_to_int8(val.y.x);
|
789 |
+
int8[3] = cast_to_int8(val.y.y);
|
790 |
+
int8[4] = cast_to_int8(val.z.x);
|
791 |
+
int8[5] = cast_to_int8(val.z.y);
|
792 |
+
int8[6] = cast_to_int8(val.w.x);
|
793 |
+
int8[7] = cast_to_int8(val.w.y);
|
794 |
+
return int64;
|
795 |
+
}
|
796 |
+
|
797 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
798 |
+
|
799 |
+
template<typename T>
|
800 |
+
inline __device__ __host__ T div_up(T m, T n)
|
801 |
+
{
|
802 |
+
return (m + n - 1) / n;
|
803 |
+
}
|
804 |
+
|
805 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
806 |
+
|
807 |
+
template<typename T, bool DO_CROSS_ATTENTION>
|
808 |
+
inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
|
809 |
+
int threads_per_value,
|
810 |
+
int threads_per_block)
|
811 |
+
{
|
812 |
+
// The amount of shared memory needed to store the Q*K^T values in float.
|
813 |
+
const int max_timesteps = min(params.timestep, params.memory_max_len);
|
814 |
+
size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
|
815 |
+
|
816 |
+
// The extra memory needed if we are not using floats for the final logits.
|
817 |
+
size_t logits_sz = 0;
|
818 |
+
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
|
819 |
+
if (sizeof(T) != 4) {
|
820 |
+
// TODO
|
821 |
+
logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
|
822 |
+
div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
|
823 |
+
}
|
824 |
+
#endif
|
825 |
+
|
826 |
+
// The total size needed during softmax.
|
827 |
+
size_t softmax_sz = qk_sz + logits_sz;
|
828 |
+
|
829 |
+
// The number of partial rows to reduce in the final reduction.
|
830 |
+
int rows_per_red = threads_per_block / threads_per_value;
|
831 |
+
// The amount of storage needed to finalize the outputs.
|
832 |
+
size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
|
833 |
+
|
834 |
+
size_t transpose_rotary_size = 0;
|
835 |
+
if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
|
836 |
+
transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
|
837 |
+
}
|
838 |
+
|
839 |
+
// The max.
|
840 |
+
return max(max(softmax_sz, red_sz), transpose_rotary_size);
|
841 |
+
}
|
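// A worked example (hypothetical values, not from the FasterTransformer source) of the
// sizing above for T = half (2 bytes), no cross attention, timestep = 1023,
// memory_max_len >= 1024, hidden_size_per_head = 128, threads_per_block = 256,
// threads_per_value = 32, and no NeoX-style rotary embedding:
//   qk_sz      = div_up(1024, 4) * 16                  = 4096 bytes
//   logits_sz  = div_up(1024, 4) * 4 * sizeof(half)    = 2048 bytes
//   softmax_sz = 4096 + 2048                           = 6144 bytes
//   red_sz     = (256 / 32) * 128 * sizeof(half) / 2   = 1024 bytes
//   result     = max(6144, 1024, 0)                    = 6144 bytes of dynamic shared memory.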
842 |
+
|
843 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
844 |
+
|
845 |
+
inline __device__ constexpr uint32_t shfl_mask(int threads)
|
846 |
+
{
|
847 |
+
return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
|
848 |
+
}
|
849 |
+
|
850 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
851 |
+
|
852 |
+
template<
|
853 |
+
// The type of the inputs. Supported types: float and half.
|
854 |
+
typename T,
|
855 |
+
// The hidden dimension per head.
|
856 |
+
int Dh,
|
857 |
+
int Dh_MAX,
|
858 |
+
// The number of threads per key.
|
859 |
+
int THREADS_PER_KEY,
|
860 |
+
// The number of threads per value.
|
861 |
+
int THREADS_PER_VALUE,
|
862 |
+
// The number of threads in a threadblock.
|
863 |
+
int THREADS_PER_BLOCK,
|
864 |
+
bool DO_CROSS_ATTENTION>
|
865 |
+
__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
|
866 |
+
{
|
867 |
+
|
868 |
+
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
|
869 |
+
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
|
870 |
+
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
|
871 |
+
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
|
872 |
+
|
873 |
+
// The size of a warp.
|
874 |
+
constexpr int WARP_SIZE = 32;
|
875 |
+
// The number of warps in a threadblock.
|
876 |
+
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
|
877 |
+
|
878 |
+
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
|
879 |
+
extern __shared__ char smem_[];
|
880 |
+
|
881 |
+
// The shared memory for the Q*K^T values and partial logits in softmax.
|
882 |
+
float* qk_smem = reinterpret_cast<float*>(smem_);
|
883 |
+
|
884 |
+
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
|
885 |
+
char* logits_smem_ = smem_;
|
886 |
+
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
|
887 |
+
if (sizeof(T) != 4) {
|
888 |
+
// TODO - change to tlength
|
889 |
+
const int max_timesteps = min(params.timestep, params.memory_max_len);
|
890 |
+
logits_smem_ +=
|
891 |
+
(DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
|
892 |
+
}
|
893 |
+
T* logits_smem = reinterpret_cast<T*>(logits_smem_);
|
894 |
+
#else
|
895 |
+
float* logits_smem = reinterpret_cast<float*>(logits_smem_);
|
896 |
+
#endif
|
897 |
+
|
898 |
+
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
|
899 |
+
T* out_smem = reinterpret_cast<T*>(smem_);
|
900 |
+
|
901 |
+
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
|
902 |
+
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
|
903 |
+
|
904 |
+
// A vector of Q or K elements for the current timestep.
|
905 |
+
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
|
906 |
+
|
907 |
+
// Use alignment for safely casting the shared buffers as Qk_vec.
|
908 |
+
// Shared memory to store Q inputs.
|
909 |
+
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
|
910 |
+
|
911 |
+
// This is one of the reasons we should have a separate kernel for cross attention
|
912 |
+
__shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
|
913 |
+
|
914 |
+
// A vector of Q or K elements for the current timestep.
|
915 |
+
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
|
916 |
+
// The number of elements per vector.
|
917 |
+
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
|
918 |
+
// Make sure the hidden size per head is a multiple of the vector size.
|
919 |
+
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
|
920 |
+
// We will use block wide reduction if needed
|
921 |
+
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
|
922 |
+
// The number of vectors per warp.
|
923 |
+
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
|
924 |
+
|
925 |
+
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
|
926 |
+
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
|
927 |
+
// tion of the thread in that chunk.
|
928 |
+
|
929 |
+
// The number of elements in a chunk of 16B (that's the x in the above formula).
|
930 |
+
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
|
931 |
+
// The number of K vectors in 16B.
|
932 |
+
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
|
933 |
+
|
934 |
+
// The batch/beam idx
|
935 |
+
const int bi = blockIdx.y;
|
936 |
+
if (params.finished != nullptr && params.finished[bi] == true) {
|
937 |
+
return;
|
938 |
+
}
|
939 |
+
// The beam idx
|
940 |
+
const int beami = bi % params.beam_width;
|
941 |
+
// The "beam-aware" batch idx
|
942 |
+
const int bbi = bi / params.beam_width;
|
943 |
+
// The head.
|
944 |
+
const int num_kv_heads = params.num_kv_heads;
|
945 |
+
const int kv_rep = (params.num_heads / num_kv_heads);
|
946 |
+
const int hi = blockIdx.x;
|
947 |
+
const int hi_kv = hi / kv_rep;
|
948 |
+
|
949 |
+
// Combine the batch and the head indices.
|
950 |
+
const int bhi = bi * params.num_heads + hi;
|
951 |
+
const int bhi_kv = bi * (params.num_heads / kv_rep) + hi_kv;
|
952 |
+
// Combine the "beam-aware" batch idx and the head indices.
|
953 |
+
const int bbhi = bbi * params.beam_width * params.num_heads + hi;
|
954 |
+
const int bbhi_kv = bbi * params.beam_width * (params.num_heads / kv_rep) + hi_kv;
|
955 |
+
// The thread in the block.
|
956 |
+
const int tidx = threadIdx.x;
|
957 |
+
|
958 |
+
const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
|
959 |
+
// Each group of kv_rep query heads shares the same kv_cache values, so only the first head in the group writes back.
|
960 |
+
const int write_kv_cache = handle_kv && (hi % kv_rep == 0);
|
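// Worked example (hypothetical values, not from the FasterTransformer source): with
// num_heads == 32 and num_kv_heads == 8, kv_rep == 4, so query heads {0,1,2,3} map to
// hi_kv == 0, heads {4,5,6,7} map to hi_kv == 1, and so on. Within each group only the
// head with hi % kv_rep == 0 (e.g. hi == 4 for group 1) has write_kv_cache set.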
961 |
+
|
962 |
+
// While doing the product Q*K^T for the different keys we track the max.
|
963 |
+
float qk_max = -FLT_MAX;
|
964 |
+
|
965 |
+
float qk = 0.0F;
|
966 |
+
|
967 |
+
// int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
|
968 |
+
const int q_base_offset = bi * params.stride + hi * Dh;
|
969 |
+
const int k_base_offset = bi * params.stride + hi_kv * Dh;
|
970 |
+
const int v_base_offset = k_base_offset;
|
971 |
+
|
972 |
+
const size_t bi_seq_len_offset = bi * params.memory_max_len;
|
973 |
+
|
974 |
+
// int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
|
975 |
+
int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
|
976 |
+
(params.length_per_sample == nullptr) ?
|
977 |
+
params.timestep :
|
978 |
+
params.length_per_sample[bi] + params.max_prefix_prompt_length;
|
979 |
+
const int first_step = max(0, tlength + 1 - params.memory_max_len);
|
980 |
+
const int tlength_circ = tlength % params.memory_max_len;
|
981 |
+
|
982 |
+
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
|
983 |
+
const bool is_masked = tidx >= QK_VECS_PER_WARP;
|
984 |
+
|
985 |
+
// The offset in the Q and K buffer also accounts for the batch.
|
986 |
+
// int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
|
987 |
+
int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
|
988 |
+
int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
|
989 |
+
int v_offset = k_offset;
|
990 |
+
|
991 |
+
// The offset in the bias buffer.
|
992 |
+
// int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
|
993 |
+
int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
|
994 |
+
int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
|
995 |
+
int v_bias_offset = k_bias_offset;
|
996 |
+
|
997 |
+
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
|
998 |
+
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
|
999 |
+
|
1000 |
+
// Trigger the loads from the Q and K buffers.
|
1001 |
+
Qk_vec q;
|
1002 |
+
zero(q);
|
1003 |
+
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
|
1004 |
+
if (params.int8_mode == 2) {
|
1005 |
+
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
|
1006 |
+
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
|
1007 |
+
const auto q_scaling = params.qkv_scale_out[0];
|
1008 |
+
const auto q_quant =
|
1009 |
+
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
|
1010 |
+
|
1011 |
+
convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
|
1012 |
+
}
|
1013 |
+
else {
|
1014 |
+
q = *reinterpret_cast<const Qk_vec*>(¶ms.q[q_offset]);
|
1015 |
+
}
|
1016 |
+
}
|
1017 |
+
|
1018 |
+
Qk_vec k;
|
1019 |
+
zero(k);
|
1020 |
+
if (DO_CROSS_ATTENTION) {
|
1021 |
+
// The 16B chunk written by the thread.
|
1022 |
+
int co = tidx / QK_VECS_IN_16B;
|
1023 |
+
// The position of the thread in that 16B chunk.
|
1024 |
+
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
|
1025 |
+
|
1026 |
+
// Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
|
1027 |
+
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
|
1028 |
+
// params.timestep*QK_ELTS_IN_16B +
|
1029 |
+
tlength * QK_ELTS_IN_16B + ci;
|
1030 |
+
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
|
1031 |
+
*reinterpret_cast<const Qk_vec*>(¶ms.k_cache[offset]) :
|
1032 |
+
k;
|
1033 |
+
}
|
1034 |
+
else {
|
1035 |
+
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
|
1036 |
+
if (params.int8_mode == 2) {
|
1037 |
+
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
|
1038 |
+
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
|
1039 |
+
const auto k_scaling = params.qkv_scale_out[1];
|
1040 |
+
const auto k_quant =
|
1041 |
+
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
|
1042 |
+
|
1043 |
+
convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
|
1044 |
+
}
|
1045 |
+
else {
|
1046 |
+
k = *reinterpret_cast<const Qk_vec*>(¶ms.k[k_offset]);
|
1047 |
+
}
|
1048 |
+
}
|
1049 |
+
}
|
1050 |
+
|
1051 |
+
// Trigger the loads from the Q and K bias buffers.
|
1052 |
+
Qk_vec q_bias;
|
1053 |
+
zero(q_bias);
|
1054 |
+
q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
|
1055 |
+
*reinterpret_cast<const Qk_vec*>(¶ms.q_bias[q_bias_offset]) :
|
1056 |
+
q_bias;
|
1057 |
+
|
1058 |
+
Qk_vec k_bias;
|
1059 |
+
zero(k_bias);
|
1060 |
+
if (handle_kv) {
|
1061 |
+
k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
|
1062 |
+
*reinterpret_cast<const Qk_vec*>(¶ms.k_bias[k_bias_offset]) :
|
1063 |
+
k_bias;
|
1064 |
+
}
|
1065 |
+
|
1066 |
+
// Computes the Q/K values with bias.
|
1067 |
+
q = add(q, q_bias);
|
1068 |
+
if (handle_kv) {
|
1069 |
+
k = add(k, k_bias);
|
1070 |
+
}
|
1071 |
+
if (do_ia3 && !is_masked) {
|
1072 |
+
k = mul<Qk_vec, Qk_vec, Qk_vec>(
|
1073 |
+
k,
|
1074 |
+
*reinterpret_cast<const Qk_vec*>(
|
1075 |
+
¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
|
1076 |
+
}
|
1077 |
+
|
1078 |
+
// Padded len
|
1079 |
+
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
|
1080 |
+
if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
|
1081 |
+
if (handle_kv) {
|
1082 |
+
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
1083 |
+
}
|
1084 |
+
else {
|
1085 |
+
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
1086 |
+
}
|
1087 |
+
}
|
1088 |
+
else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
|
1089 |
+
const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
|
1090 |
+
|
1091 |
+
T* q_smem = reinterpret_cast<T*>(smem_);
|
1092 |
+
T* k_smem = q_smem + params.rotary_embedding_dim;
|
1093 |
+
|
1094 |
+
const int half_rotary_dim = params.rotary_embedding_dim / 2;
|
1095 |
+
const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
|
1096 |
+
const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
|
1097 |
+
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
|
1098 |
+
|
1099 |
+
assert(half_rotary_dim % QK_VEC_SIZE == 0);
|
1100 |
+
|
1101 |
+
if (do_rotary) {
|
1102 |
+
*reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
|
1103 |
+
|
1104 |
+
if (handle_kv) {
|
1105 |
+
*reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
|
1106 |
+
}
|
1107 |
+
}
|
1108 |
+
|
1109 |
+
__syncthreads();
|
1110 |
+
|
1111 |
+
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
|
1112 |
+
constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
|
1113 |
+
if (do_rotary) {
|
1114 |
+
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
1115 |
+
|
1116 |
+
if (handle_kv) {
|
1117 |
+
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
1118 |
+
|
1119 |
+
mmha::apply_rotary_embedding(
|
1120 |
+
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
1121 |
+
|
1122 |
+
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
1123 |
+
}
|
1124 |
+
else {
|
1125 |
+
mmha::apply_rotary_embedding(
|
1126 |
+
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
|
1127 |
+
}
|
1128 |
+
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
1129 |
+
}
|
1130 |
+
|
1131 |
+
__syncthreads();
|
1132 |
+
|
1133 |
+
if (do_rotary) {
|
1134 |
+
q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
|
1135 |
+
if (handle_kv) {
|
1136 |
+
k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
|
1137 |
+
}
|
1138 |
+
}
|
1139 |
+
|
1140 |
+
__syncthreads();
|
1141 |
+
}
|
1142 |
+
|
1143 |
+
if (!is_masked) {
|
1144 |
+
// Store the Q values to shared memory.
|
1145 |
+
*reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
|
1146 |
+
|
1147 |
+
// Store Dh values of k_bias into smem, since we will need to add them later
|
1148 |
+
// if params.timestep == 0
|
1149 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1150 |
+
*reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
|
1151 |
+
}
|
1152 |
+
|
1153 |
+
// Write the K values to the global memory cache.
|
1154 |
+
//
|
1155 |
+
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
|
1156 |
+
// system. We designed it this way as it allows much better memory loads (and there are many
|
1157 |
+
// more loads) + the stores are really "write and forget" since we won't need the ack before
|
1158 |
+
// the end of the kernel. There's plenty of time for the transactions to complete.
|
1159 |
+
|
1160 |
+
// The 16B chunk written by the thread.
|
1161 |
+
int co = tidx / QK_VECS_IN_16B;
|
1162 |
+
// The position of the thread in that 16B chunk.
|
1163 |
+
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
|
1164 |
+
|
1165 |
+
// Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
|
1166 |
+
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
|
1167 |
+
// params.timestep*QK_ELTS_IN_16B +
|
1168 |
+
tlength_circ * QK_ELTS_IN_16B + ci;
|
1169 |
+
|
1170 |
+
if (write_kv_cache) {
|
1171 |
+
// Trigger the stores to global memory.
|
1172 |
+
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
|
1173 |
+
*reinterpret_cast<Qk_vec*>(¶ms.k_cache[offset]) = k;
|
1174 |
+
}
|
1175 |
+
}
|
1176 |
+
|
1177 |
+
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
|
1178 |
+
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
1179 |
+
using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
|
1180 |
+
#else
|
1181 |
+
using Qk_vec_acum = Qk_vec;
|
1182 |
+
#endif
|
1183 |
+
qk = dot<Qk_vec_acum, Qk_vec>(q, k);
|
1184 |
+
if (QK_VECS_PER_WARP <= WARP_SIZE) {
|
1185 |
+
#pragma unroll
|
1186 |
+
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
|
1187 |
+
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
|
1188 |
+
}
|
1189 |
+
}
|
1190 |
+
}
|
1191 |
+
|
1192 |
+
if (QK_VECS_PER_WARP > WARP_SIZE) {
|
1193 |
+
constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
|
1194 |
+
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
|
1195 |
+
}
|
1196 |
+
|
1197 |
+
// Store that value in shared memory. Keep the Q*K^T value in register for softmax.
|
1198 |
+
if (tidx == 0) {
|
1199 |
+
// Normalize qk.
|
1200 |
+
qk *= params.inv_sqrt_dh;
|
1201 |
+
if (params.relative_attention_bias != nullptr) {
|
1202 |
+
// TODO (Haotian): check whether we should replace hi with hi_kv,
|
1203 |
+
// although params.relative_attention_bias is usually not used.
|
1204 |
+
qk = add(qk,
|
1205 |
+
params.relative_attention_bias[hi * params.relative_attention_bias_stride
|
1206 |
+
* params.relative_attention_bias_stride
|
1207 |
+
+ (tlength - padd_len) * params.relative_attention_bias_stride
|
1208 |
+
+ (tlength - padd_len)]);
|
1209 |
+
}
|
1210 |
+
// Add alibi positional encoding
|
1211 |
+
// qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
|
1212 |
+
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
|
1213 |
+
|
1214 |
+
qk_max = qk;
|
1215 |
+
qk_smem[tlength - first_step] = qk;
|
1216 |
+
// qk_smem[params.timestep] = qk;
|
1217 |
+
}
|
1218 |
+
|
1219 |
+
// Make sure the data is in shared memory.
|
1220 |
+
__syncthreads();
|
1221 |
+
|
1222 |
+
// The type of queries and keys for the math in the Q*K^T product.
|
1223 |
+
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
|
1224 |
+
// The number of elements per vector.
|
1225 |
+
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
|
1226 |
+
// Make sure the hidden size per head is a multiple of the vector size.
|
1227 |
+
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
|
1228 |
+
// The number of elements per thread.
|
1229 |
+
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
|
1230 |
+
// The number of vectors per thread.
|
1231 |
+
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
|
1232 |
+
|
1233 |
+
// The position the first key loaded by each thread from the cache buffer (for this B * H).
|
1234 |
+
int ko = tidx / THREADS_PER_KEY;
|
1235 |
+
// The position of the thread in the chunk of keys.
|
1236 |
+
int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
|
1237 |
+
|
1238 |
+
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
|
1239 |
+
|
1240 |
+
// Load the Q values from shared memory. The values are reused during the loop on K.
|
1241 |
+
K_vec q_vec[K_VECS_PER_THREAD];
|
1242 |
+
#pragma unroll
|
1243 |
+
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
|
1244 |
+
q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
|
1245 |
+
}
|
1246 |
+
|
1247 |
+
K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
|
1248 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1249 |
+
#pragma unroll
|
1250 |
+
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
|
1251 |
+
k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
|
1252 |
+
}
|
1253 |
+
}
|
1254 |
+
|
1255 |
+
// The number of timesteps loaded per iteration.
|
1256 |
+
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
|
1257 |
+
// The number of keys per warp.
|
1258 |
+
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
|
1259 |
+
|
1260 |
+
// The base pointer for the key in the cache buffer.
|
1261 |
+
T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
|
1262 |
+
// Base pointer for the beam's batch, before offsetting with indirection buffer
|
1263 |
+
T* k_cache_batch = ¶ms.k_cache[bbhi_kv * params.memory_max_len * Dh + ki];
|
1264 |
+
|
1265 |
+
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
|
1266 |
+
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
|
1267 |
+
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
|
1268 |
+
|
1269 |
+
// The prefix prompt length, if present.
|
1270 |
+
const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
|
1271 |
+
|
1272 |
+
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
|
1273 |
+
const bool has_beams = params.cache_indir != nullptr;
|
1274 |
+
const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr;
|
1275 |
+
|
1276 |
+
for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
|
1277 |
+
const int ti_circ = ti % params.memory_max_len;
|
1278 |
+
|
1279 |
+
// The keys loaded from the key cache.
|
1280 |
+
K_vec k[K_VECS_PER_THREAD];
|
1281 |
+
K_vec k_vec_zero;
|
1282 |
+
zero(k_vec_zero);
|
1283 |
+
#pragma unroll
|
1284 |
+
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
|
1285 |
+
int jj = ii * params.memory_max_len + ti_circ;
|
1286 |
+
// if( ti < params.timestep ) {
|
1287 |
+
const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
|
1288 |
+
if (ti < tlength) {
|
1289 |
+
if (!within_bounds) {
|
1290 |
+
k[ii] = k_vec_zero;
|
1291 |
+
}
|
1292 |
+
else {
|
1293 |
+
if (has_beams) {
|
1294 |
+
const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
|
1295 |
+
k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
|
1296 |
+
}
|
1297 |
+
else {
|
1298 |
+
k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
|
1299 |
+
}
|
1300 |
+
}
|
1301 |
+
// add bias and update k_cache
|
1302 |
+
if (DO_CROSS_ATTENTION && params.timestep == 0) {
|
1303 |
+
k[ii] = add(k[ii], k_bias_vec[ii]);
|
1304 |
+
|
1305 |
+
if (do_ia3) {
|
1306 |
+
k[ii] = mul<K_vec, K_vec, K_vec>(
|
1307 |
+
k[ii],
|
1308 |
+
*reinterpret_cast<const K_vec*>(
|
1309 |
+
¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
|
1310 |
+
+ ii * THREADS_PER_KEY * K_VEC_SIZE]));
|
1311 |
+
}
|
1312 |
+
|
1313 |
+
if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
|
1314 |
+
*reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
|
1315 |
+
}
|
1316 |
+
}
|
1317 |
+
}
|
1318 |
+
}
|
1319 |
+
|
1320 |
+
// Perform the dot product and normalize qk.
|
1321 |
+
//
|
        // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
        float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];

        // Store the product to shared memory. There's one qk value per timestep. Update the max.
        // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
        if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
            if (params.relative_attention_bias != nullptr) {
                qk = add(qk,
                         params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                            * params.relative_attention_bias_stride
                                                        + tlength * params.relative_attention_bias_stride + ti]);
            }
            if (params.linear_bias_slopes != nullptr) {
                // Apply the linear position bias: (ki - qi) * slope[hi].
                // The padding token locates between the input context and the generated tokens.
                // We need to remove the number of padding tokens in the distance computation.
                //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
                //   token: i i i i p p p o o o where i=input, p=pad, o=output.
                // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
                int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
                float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;

                qk += mul<float, float, float>(params.linear_bias_slopes[hi], dist);
            }
            // Add alibi positional encoding
            // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
            qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
            qk_smem[ti - first_step] = qk;
        }
    }

    // Perform the final reduction to compute the max inside each warp.
    //
    // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
    // group so it's not needed to run the reduction inside the group (again).
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Decompose the thread index into warp and lane.
    const int warp = tidx / WARP_SIZE;
    const int lane = tidx % WARP_SIZE;

    // The warp leader writes the max to shared memory.
    if (lane == 0) {
        red_smem[warp] = qk_max;
    }

    // Make sure the products are in shared memory.
    __syncthreads();

    // The warps finalize the reduction.
    qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Broadcast to all the threads in the warp.
    qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

    // Compute the logits and start the sum.
    float sum = 0.f;
    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
        float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
        sum += logit;
        qk_smem[ti - first_step] = logit;
    }

    // Compute the sum.
    sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);

    // Normalize the logits.
    float inv_sum = __fdividef(1.f, sum + 1.e-6f);
    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
    const size_t cross_attention_out_offset =
        params.is_return_cross_attentions ?
            bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
            0;
    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
        float logit = qk_smem[ti - first_step] * inv_sum;
        if (params.is_return_cross_attentions) {
            params.cross_attention_out[cross_attention_out_offset + ti] = logit;
        }
        convert_from_float(logits_smem[ti - first_step], logit);
    }

    // Put Values part below so we leverage __syncthreads
    // from the previous step

    // The number of elements per vector.
    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
    // A vector of V elements for the current timestep.
    using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

    // The value computed by this thread.
    int vo = tidx / THREADS_PER_VALUE;
    // The hidden dimensions computed by this particular thread.
    int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;

    // The base pointer for the value in the cache buffer.
    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    T* v_cache_batch = &params.v_cache[bbhi_kv * params.memory_max_len * Dh + vi];

    // The number of values processed per iteration of the loop.
    constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

    // One group of threads computes the product(s) for the current timestep.
    V_vec v_bias;
    zero(v_bias);
    // if( vo == params.timestep % V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        if (handle_kv) {
            if (vo == tlength % V_PER_ITER) {
                // Trigger the loads from the V bias buffer.
                if (params.v_bias != nullptr) {
                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
                }
                if (DO_CROSS_ATTENTION) {
                    *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
                }
            }
        }
    }

    // From previous, before values, step
    // Also make sure the logits are in shared memory.
    __syncthreads();

    // Values continued
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
    using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
    using V_vec_acum = V_vec;
#endif
    // The partial outputs computed by each thread.
    V_vec_acum out;
    zero(out);

    // Loop over the timesteps to compute the partial outputs.
    // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
            const int ti_circ = ti % params.memory_max_len;

            // Fetch offset based on cache_indir when beam sampling
            const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
            const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
            // Load the values from the cache.
            V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
            if (DO_CROSS_ATTENTION && params.timestep == 0) {
                v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
                if (do_ia3) {
                    v = mul<V_vec, V_vec, V_vec>(
                        v,
                        *reinterpret_cast<const V_vec*>(
                            &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
                }
                *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
            }
            // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
            float logit = logits_smem[ti - first_step];
            out = fma(logit, cast_to_float(v), out);
#else
            T logit = logits_smem[ti - first_step];

            // Update the partial sums.
            out = fma(logit, v, out);
#endif
        }
    }

    // One group of threads computes the product(s) for the current timestep.
    // if( vo == params.timestep % V_PER_ITER ) {
    if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {

        V_vec v;
        if (DO_CROSS_ATTENTION) {
            v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
        }
        else {
            // Trigger the loads from the V buffer.
            const auto v_offset = v_base_offset + vi;
            if (params.int8_mode == 2) {
                using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
                const auto v_scaling = params.qkv_scale_out[2];
                const auto v_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

                convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
            }
            else {
                v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
            }
            // Trigger the loads from the V bias buffer.
            // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
        }

        // Compute the V values with bias.
        v = add(v, v_bias);
        if (write_kv_cache) {

            if (do_ia3) {
                v = mul<V_vec, V_vec, V_vec>(
                    v,
                    *reinterpret_cast<const V_vec*>(
                        &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
            }

            // Store the values with bias back to global memory in the cache for V.
            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
        }

        // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
        out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
#else
        // out = fma(logits_smem[params.timestep], v, out);
        out = fma(logits_smem[tlength - first_step], v, out);
#endif
    }

    // Make sure we can start writing to shared memory.
    __syncthreads();

    // Run the final reduction amongst the different groups computing different partial outputs.
    if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {

            // The midpoint in the number of active groups.
            int midpoint = active_groups / 2;

            // The upper part of active threads store to shared memory.
            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
                convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
                *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
            }
            __syncthreads();

            // The bottom warps update their values.
            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
                out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
            }
            __syncthreads();
        }
    }

    // Output the final values.
    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
        if (params.int8_mode == 2) {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
            out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
            *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
                cast_to_int8(out);
        }
        else {
            convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
        }
#else
        // TODO: support int8_mode?
        *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
#endif
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace mmha

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
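The kernel body above computes the attention softmax with a two-stage reduction: a warp-level butterfly over __shfl_xor_sync produces per-warp maxima, the warp leaders stage those in red_smem, and a second shuffle pass plus a broadcast gives every thread the block-wide qk_max; block_sum reuses the same shape for the denominator. Below is a minimal standalone sketch of that pattern for a block-wide max, assuming a 128-thread block; block_max_demo, the buffer names, and the launch in main are invented for the illustration and are not part of FasterTransformer or AutoAWQ.

// Minimal sketch of the two-stage (warp shuffle + shared memory) block reduction
// used by the kernel above. Standalone illustration only.
#include <cstdio>
#include <cfloat>
#include <cuda_runtime.h>

constexpr int WARP_SIZE = 32;
constexpr int THREADS_PER_BLOCK = 128;
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;

__global__ void block_max_demo(const float* in, float* out, int n)
{
    const int tidx = threadIdx.x;
    float v = -FLT_MAX;
    // Strided pass: each thread folds several elements into a private max.
    for (int i = tidx; i < n; i += THREADS_PER_BLOCK) {
        v = fmaxf(v, in[i]);
    }

    // Stage 1: butterfly reduction inside each warp.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, mask));
    }

    // Stage 2: warp leaders publish to shared memory, low lanes finish the job.
    __shared__ float red_smem[WARPS_PER_BLOCK];
    const int warp = tidx / WARP_SIZE;
    const int lane = tidx % WARP_SIZE;
    if (lane == 0) {
        red_smem[warp] = v;
    }
    __syncthreads();

    v = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, mask));
    }

    // Broadcast lane 0's result so every thread sees the block max.
    v = __shfl_sync(0xffffffffu, v, 0);
    if (tidx == 0) {
        *out = v;
    }
}

int main()
{
    const int n = 1000;
    float h_in[n];
    for (int i = 0; i < n; ++i) h_in[i] = (i * 37 % 101) - 50.0f;
    float *d_in, *d_out, h_out = 0.0f;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    block_max_demo<<<1, THREADS_PER_BLOCK>>>(d_in, d_out, n);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block max = %f\n", h_out);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Keeping the strided per-thread pass, the intra-warp shuffle, and the cross-warp shared-memory step inside one kernel avoids atomics and extra launches, which is the same reason the original kernel runs its reductions inline.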
AutoAWQ_kernels/awq_ext/attention/decoder_masked_multihead_attention_utils.h
ADDED
@@ -0,0 +1,1786 @@
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <stdint.h>

using namespace fastertransformer;

namespace mmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float8_ {
    float2 x;
    float2 y;
    float2 z;
    float2 w;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float4_ {
    float2 x;
    float2 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
struct bf16_4_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct bf16_8_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
    __nv_bfloat162 z;
    __nv_bfloat162 w;
};
#endif
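Float8_, Float4_, and the bf16 packs above exist so a thread can move K/V data 16 or 32 bytes at a time and then run the overloaded add/mul/fma helpers defined later in this header on a whole pack. Below is a minimal sketch of that packed-view access pattern; Pack4, pack_add_demo, and the launch configuration are invented for the sketch and simply mimic the Float4_ layout, they are not part of this header.

// Minimal sketch: reinterpret a float buffer as 16-byte packs, operate on the
// whole pack, write back. Illustration only; not FasterTransformer/AutoAWQ API.
#include <cstdio>
#include <cuda_runtime.h>

struct Pack4 {      // same layout idea as Float4_: two float2 = 16 bytes
    float2 x;
    float2 y;
};

__device__ inline Pack4 add(Pack4 a, Pack4 b)
{
    Pack4 c;
    c.x = make_float2(a.x.x + b.x.x, a.x.y + b.x.y);
    c.y = make_float2(a.y.x + b.y.x, a.y.y + b.y.y);
    return c;
}

__global__ void pack_add_demo(const float* a, const float* b, float* out, int n_packs)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n_packs) {
        // Each thread handles 4 floats at once through the packed view.
        Pack4 pa = reinterpret_cast<const Pack4*>(a)[i];
        Pack4 pb = reinterpret_cast<const Pack4*>(b)[i];
        reinterpret_cast<Pack4*>(out)[i] = add(pa, pb);
    }
}

int main()
{
    const int n = 64;                       // number of floats, multiple of 4
    float h_a[n], h_b[n], h_out[n];
    for (int i = 0; i < n; ++i) { h_a[i] = i; h_b[i] = 2.0f * i; }
    float *d_a, *d_b, *d_out;
    cudaMalloc(&d_a, n * sizeof(float));
    cudaMalloc(&d_b, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_a, h_a, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, n * sizeof(float), cudaMemcpyHostToDevice);
    pack_add_demo<<<1, 32>>>(d_a, d_b, d_out, n / 4);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[5] = %f (expect %f)\n", h_out[5], 3.0f * 5);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_out);
    return 0;
}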
62 |
+
|
63 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
64 |
+
|
65 |
+
template<typename T>
|
66 |
+
struct num_elems;
|
67 |
+
template<>
|
68 |
+
struct num_elems<float> {
|
69 |
+
static constexpr int value = 1;
|
70 |
+
};
|
71 |
+
template<>
|
72 |
+
struct num_elems<float2> {
|
73 |
+
static constexpr int value = 2;
|
74 |
+
};
|
75 |
+
template<>
|
76 |
+
struct num_elems<float4> {
|
77 |
+
static constexpr int value = 4;
|
78 |
+
};
|
79 |
+
template<>
|
80 |
+
struct num_elems<Float4_> {
|
81 |
+
static constexpr int value = 4;
|
82 |
+
};
|
83 |
+
template<>
|
84 |
+
struct num_elems<Float8_> {
|
85 |
+
static constexpr int value = 8;
|
86 |
+
};
|
87 |
+
|
88 |
+
template<>
|
89 |
+
struct num_elems<uint32_t> {
|
90 |
+
static constexpr int value = 2;
|
91 |
+
};
|
92 |
+
template<>
|
93 |
+
struct num_elems<uint2> {
|
94 |
+
static constexpr int value = 4;
|
95 |
+
};
|
96 |
+
template<>
|
97 |
+
struct num_elems<uint4> {
|
98 |
+
static constexpr int value = 8;
|
99 |
+
};
|
100 |
+
|
101 |
+
#ifdef ENABLE_BF16
|
102 |
+
template<>
|
103 |
+
struct num_elems<__nv_bfloat162> {
|
104 |
+
static constexpr int value = 2;
|
105 |
+
};
|
106 |
+
template<>
|
107 |
+
struct num_elems<bf16_4_t> {
|
108 |
+
static constexpr int value = 4;
|
109 |
+
};
|
110 |
+
template<>
|
111 |
+
struct num_elems<bf16_8_t> {
|
112 |
+
static constexpr int value = 8;
|
113 |
+
};
|
114 |
+
#endif
|
115 |
+
|
116 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
117 |
+
|
118 |
+
template<typename T, int N>
|
119 |
+
struct packed_type;
|
120 |
+
template<typename T>
|
121 |
+
struct packed_type<T, 1> {
|
122 |
+
using type = T;
|
123 |
+
};
|
124 |
+
template<>
|
125 |
+
struct packed_type<int8_t, 2> {
|
126 |
+
using type = int16_t;
|
127 |
+
};
|
128 |
+
template<>
|
129 |
+
struct packed_type<int8_t, 4> {
|
130 |
+
using type = int32_t;
|
131 |
+
};
|
132 |
+
template<>
|
133 |
+
struct packed_type<int8_t, 8> {
|
134 |
+
using type = int64_t;
|
135 |
+
};
|
136 |
+
|
137 |
+
template<>
|
138 |
+
struct packed_type<float, 2> {
|
139 |
+
using type = float2;
|
140 |
+
};
|
141 |
+
template<>
|
142 |
+
struct packed_type<float, 4> {
|
143 |
+
using type = float4;
|
144 |
+
};
|
145 |
+
template<>
|
146 |
+
struct packed_type<float, 8> {
|
147 |
+
using type = Float8_;
|
148 |
+
};
|
149 |
+
|
150 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
151 |
+
|
152 |
+
inline __device__ float add(float a, float b)
|
153 |
+
{
|
154 |
+
return a + b;
|
155 |
+
}
|
156 |
+
|
157 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
158 |
+
|
159 |
+
inline __device__ float2 add(float2 a, float2 b)
|
160 |
+
{
|
161 |
+
float2 c;
|
162 |
+
c.x = add(a.x, b.x);
|
163 |
+
c.y = add(a.y, b.y);
|
164 |
+
return c;
|
165 |
+
}
|
166 |
+
|
167 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
168 |
+
|
169 |
+
inline __device__ float4 add(float4 a, float4 b)
|
170 |
+
{
|
171 |
+
float4 c;
|
172 |
+
c.x = add(a.x, b.x);
|
173 |
+
c.y = add(a.y, b.y);
|
174 |
+
c.z = add(a.z, b.z);
|
175 |
+
c.w = add(a.w, b.w);
|
176 |
+
return c;
|
177 |
+
}
|
178 |
+
|
179 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
180 |
+
|
181 |
+
#ifdef ENABLE_BF16
|
182 |
+
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
|
183 |
+
{
|
184 |
+
return a + b;
|
185 |
+
}
|
186 |
+
|
187 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
188 |
+
|
189 |
+
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
|
190 |
+
{
|
191 |
+
return bf16hadd2(a, b);
|
192 |
+
}
|
193 |
+
|
194 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
195 |
+
|
196 |
+
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
|
197 |
+
{
|
198 |
+
bf16_4_t c;
|
199 |
+
c.x = add(a.x, b.x);
|
200 |
+
c.y = add(a.y, b.y);
|
201 |
+
return c;
|
202 |
+
}
|
203 |
+
|
204 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
205 |
+
|
206 |
+
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
|
207 |
+
{
|
208 |
+
bf16_8_t c;
|
209 |
+
c.x = add(a.x, b.x);
|
210 |
+
c.y = add(a.y, b.y);
|
211 |
+
c.z = add(a.z, b.z);
|
212 |
+
c.w = add(a.w, b.w);
|
213 |
+
return c;
|
214 |
+
}
|
215 |
+
#endif // ENABLE_BF16
|
216 |
+
|
217 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
218 |
+
|
219 |
+
inline __device__ uint16_t add(uint16_t a, uint16_t b)
|
220 |
+
{
|
221 |
+
uint16_t c;
|
222 |
+
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
223 |
+
return c;
|
224 |
+
}
|
225 |
+
|
226 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
227 |
+
|
228 |
+
inline __device__ uint32_t add(uint32_t a, uint32_t b)
|
229 |
+
{
|
230 |
+
uint32_t c;
|
231 |
+
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
232 |
+
return c;
|
233 |
+
}
|
234 |
+
|
235 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
236 |
+
|
237 |
+
inline __device__ uint2 add(uint2 a, uint2 b)
|
238 |
+
{
|
239 |
+
uint2 c;
|
240 |
+
c.x = add(a.x, b.x);
|
241 |
+
c.y = add(a.y, b.y);
|
242 |
+
return c;
|
243 |
+
}
|
244 |
+
|
245 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
246 |
+
|
247 |
+
inline __device__ uint4 add(uint4 a, uint4 b)
|
248 |
+
{
|
249 |
+
uint4 c;
|
250 |
+
c.x = add(a.x, b.x);
|
251 |
+
c.y = add(a.y, b.y);
|
252 |
+
c.z = add(a.z, b.z);
|
253 |
+
c.w = add(a.w, b.w);
|
254 |
+
return c;
|
255 |
+
}
|
256 |
+
|
257 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
258 |
+
|
259 |
+
inline __device__ uint16_t float_to_half(float f)
|
260 |
+
{
|
261 |
+
union {
|
262 |
+
uint32_t u32;
|
263 |
+
uint16_t u16[2];
|
264 |
+
} tmp;
|
265 |
+
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
|
266 |
+
float zero = 0.f;
|
267 |
+
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
|
268 |
+
#else
|
269 |
+
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
270 |
+
#endif
|
271 |
+
return tmp.u16[0];
|
272 |
+
}
|
273 |
+
|
274 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
275 |
+
|
276 |
+
inline __device__ uint32_t float2_to_half2(float2 f)
|
277 |
+
{
|
278 |
+
union {
|
279 |
+
uint32_t u32;
|
280 |
+
uint16_t u16[2];
|
281 |
+
} tmp;
|
282 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
283 |
+
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
284 |
+
#else
|
285 |
+
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
286 |
+
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
287 |
+
#endif
|
288 |
+
return tmp.u32;
|
289 |
+
}
|
290 |
+
|
291 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
292 |
+
|
293 |
+
inline __device__ float half_to_float(uint16_t h)
|
294 |
+
{
|
295 |
+
float f;
|
296 |
+
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
297 |
+
return f;
|
298 |
+
}
|
299 |
+
|
300 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
301 |
+
|
302 |
+
inline __device__ float2 half2_to_float2(uint32_t v)
|
303 |
+
{
|
304 |
+
uint16_t lo, hi;
|
305 |
+
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
306 |
+
return make_float2(half_to_float(lo), half_to_float(hi));
|
307 |
+
}
|
308 |
+
|
309 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
310 |
+
|
311 |
+
inline __device__ float add(float a, uint16_t b)
|
312 |
+
{
|
313 |
+
return a + half_to_float(b);
|
314 |
+
}
|
315 |
+
|
316 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
317 |
+
|
318 |
+
#ifdef ENABLE_BF16
|
319 |
+
inline __device__ float add(float a, __nv_bfloat16 b)
|
320 |
+
{
|
321 |
+
return a + __bfloat162float(b);
|
322 |
+
}
|
323 |
+
#endif
|
324 |
+
|
325 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
326 |
+
|
327 |
+
inline __device__ float2 add(uint32_t a, float2 fb)
|
328 |
+
{
|
329 |
+
float2 fa = half2_to_float2(a);
|
330 |
+
return add(fa, fb);
|
331 |
+
}
|
332 |
+
|
333 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
334 |
+
|
335 |
+
inline __device__ Float4_ add(uint2 a, Float4_ fb)
|
336 |
+
{
|
337 |
+
Float4_ fc;
|
338 |
+
fc.x = add(a.x, fb.x);
|
339 |
+
fc.y = add(a.y, fb.y);
|
340 |
+
return fc;
|
341 |
+
}
|
342 |
+
|
343 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
344 |
+
|
345 |
+
inline __device__ Float8_ add(uint4 a, Float8_ fb)
|
346 |
+
{
|
347 |
+
Float8_ fc;
|
348 |
+
fc.x = add(a.x, fb.x);
|
349 |
+
fc.y = add(a.y, fb.y);
|
350 |
+
fc.z = add(a.z, fb.z);
|
351 |
+
fc.w = add(a.w, fb.w);
|
352 |
+
return fc;
|
353 |
+
}
|
354 |
+
|
355 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
356 |
+
|
357 |
+
inline __device__ uint32_t h0_h0(uint16_t a)
|
358 |
+
{
|
359 |
+
uint32_t b;
|
360 |
+
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
361 |
+
return b;
|
362 |
+
}
|
363 |
+
|
364 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
365 |
+
|
366 |
+
inline __device__ float fma(float a, float b, float c)
|
367 |
+
{
|
368 |
+
return a * b + c;
|
369 |
+
}
|
370 |
+
|
371 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
372 |
+
|
373 |
+
inline __device__ float2 fma(float2 a, float2 b, float2 c)
|
374 |
+
{
|
375 |
+
float2 d;
|
376 |
+
d.x = fma(a.x, b.x, c.x);
|
377 |
+
d.y = fma(a.y, b.y, c.y);
|
378 |
+
return d;
|
379 |
+
}
|
380 |
+
|
381 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
382 |
+
|
383 |
+
inline __device__ float2 fma(float a, float2 b, float2 c)
|
384 |
+
{
|
385 |
+
float2 d;
|
386 |
+
d.x = fma(a, b.x, c.x);
|
387 |
+
d.y = fma(a, b.y, c.y);
|
388 |
+
return d;
|
389 |
+
}
|
390 |
+
|
391 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
392 |
+
|
393 |
+
inline __device__ float4 fma(float4 a, float4 b, float4 c)
|
394 |
+
{
|
395 |
+
float4 d;
|
396 |
+
d.x = fma(a.x, b.x, c.x);
|
397 |
+
d.y = fma(a.y, b.y, c.y);
|
398 |
+
d.z = fma(a.z, b.z, c.z);
|
399 |
+
d.w = fma(a.w, b.w, c.w);
|
400 |
+
return d;
|
401 |
+
}
|
402 |
+
|
403 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
404 |
+
|
405 |
+
inline __device__ float4 fma(float a, float4 b, float4 c)
|
406 |
+
{
|
407 |
+
float4 d;
|
408 |
+
d.x = fma(a, b.x, c.x);
|
409 |
+
d.y = fma(a, b.y, c.y);
|
410 |
+
d.z = fma(a, b.z, c.z);
|
411 |
+
d.w = fma(a, b.w, c.w);
|
412 |
+
return d;
|
413 |
+
}
|
414 |
+
|
415 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
416 |
+
|
417 |
+
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
|
418 |
+
{
|
419 |
+
Float4_ d;
|
420 |
+
d.x = fma(a, b.x, c.x);
|
421 |
+
d.y = fma(a, b.y, c.y);
|
422 |
+
return d;
|
423 |
+
}
|
424 |
+
|
425 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
426 |
+
|
427 |
+
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
|
428 |
+
{
|
429 |
+
Float8_ d;
|
430 |
+
d.x = fma(a, b.x, c.x);
|
431 |
+
d.y = fma(a, b.y, c.y);
|
432 |
+
d.z = fma(a, b.z, c.z);
|
433 |
+
d.w = fma(a, b.w, c.w);
|
434 |
+
return d;
|
435 |
+
}
|
436 |
+
|
437 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
438 |
+
|
439 |
+
#ifdef ENABLE_BF16
|
440 |
+
inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
|
441 |
+
{
|
442 |
+
float2 fa = bf1622float2(a);
|
443 |
+
return add(fa, fb);
|
444 |
+
}
|
445 |
+
|
446 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
447 |
+
|
448 |
+
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
|
449 |
+
{
|
450 |
+
Float4_ fc;
|
451 |
+
fc.x = add(a.x, fb.x);
|
452 |
+
fc.y = add(a.y, fb.y);
|
453 |
+
return fc;
|
454 |
+
}
|
455 |
+
|
456 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
457 |
+
|
458 |
+
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
|
459 |
+
{
|
460 |
+
Float8_ fc;
|
461 |
+
fc.x = add(a.x, fb.x);
|
462 |
+
fc.y = add(a.y, fb.y);
|
463 |
+
fc.z = add(a.z, fb.z);
|
464 |
+
fc.w = add(a.w, fb.w);
|
465 |
+
return fc;
|
466 |
+
}
|
467 |
+
#endif // ENABLE_BF16
|
468 |
+
|
469 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
470 |
+
|
471 |
+
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
|
472 |
+
{
|
473 |
+
uint32_t d;
|
474 |
+
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
475 |
+
return d;
|
476 |
+
}
|
477 |
+
|
478 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
479 |
+
|
480 |
+
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
|
481 |
+
{
|
482 |
+
return fma(h0_h0(a), b, c);
|
483 |
+
}
|
484 |
+
|
485 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
486 |
+
|
487 |
+
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
|
488 |
+
{
|
489 |
+
uint2 d;
|
490 |
+
d.x = fma(a.x, b.x, c.x);
|
491 |
+
d.y = fma(a.y, b.y, c.y);
|
492 |
+
return d;
|
493 |
+
}
|
494 |
+
|
495 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
496 |
+
|
497 |
+
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
|
498 |
+
{
|
499 |
+
uint32_t s = h0_h0(a);
|
500 |
+
uint2 d;
|
501 |
+
d.x = fma(s, b.x, c.x);
|
502 |
+
d.y = fma(s, b.y, c.y);
|
503 |
+
return d;
|
504 |
+
}
|
505 |
+
|
506 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
507 |
+
|
508 |
+
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
|
509 |
+
{
|
510 |
+
uint4 d;
|
511 |
+
d.x = fma(a.x, b.x, c.x);
|
512 |
+
d.y = fma(a.y, b.y, c.y);
|
513 |
+
d.z = fma(a.z, b.z, c.z);
|
514 |
+
d.w = fma(a.w, b.w, c.w);
|
515 |
+
return d;
|
516 |
+
}
|
517 |
+
|
518 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
519 |
+
|
520 |
+
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
|
521 |
+
{
|
522 |
+
uint32_t s = h0_h0(a);
|
523 |
+
uint4 d;
|
524 |
+
d.x = fma(s, b.x, c.x);
|
525 |
+
d.y = fma(s, b.y, c.y);
|
526 |
+
d.z = fma(s, b.z, c.z);
|
527 |
+
d.w = fma(s, b.w, c.w);
|
528 |
+
return d;
|
529 |
+
}
|
530 |
+
|
531 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
532 |
+
|
533 |
+
inline __device__ float fma(uint16_t a, uint16_t b, float fc)
|
534 |
+
{
|
535 |
+
float fa = half_to_float(a);
|
536 |
+
float fb = half_to_float(b);
|
537 |
+
return fa * fb + fc;
|
538 |
+
}
|
539 |
+
|
540 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
541 |
+
|
542 |
+
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
|
543 |
+
{
|
544 |
+
float2 fa = half2_to_float2(a);
|
545 |
+
float2 fb = half2_to_float2(b);
|
546 |
+
return fma(fa, fb, fc);
|
547 |
+
}
|
548 |
+
|
549 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
550 |
+
|
551 |
+
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
|
552 |
+
{
|
553 |
+
return fma(h0_h0(a), b, fc);
|
554 |
+
}
|
555 |
+
|
556 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
557 |
+
|
558 |
+
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
|
559 |
+
{
|
560 |
+
Float4_ fd;
|
561 |
+
fd.x = fma(a.x, b.x, fc.x);
|
562 |
+
fd.y = fma(a.y, b.y, fc.y);
|
563 |
+
return fd;
|
564 |
+
}
|
565 |
+
|
566 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
567 |
+
|
568 |
+
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
|
569 |
+
{
|
570 |
+
uint32_t s = h0_h0(a);
|
571 |
+
Float4_ fd;
|
572 |
+
fd.x = fma(s, b.x, fc.x);
|
573 |
+
fd.y = fma(s, b.y, fc.y);
|
574 |
+
return fd;
|
575 |
+
}
|
576 |
+
|
577 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
578 |
+
|
579 |
+
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
|
580 |
+
{
|
581 |
+
Float8_ fd;
|
582 |
+
fd.x = fma(a.x, b.x, fc.x);
|
583 |
+
fd.y = fma(a.y, b.y, fc.y);
|
584 |
+
fd.z = fma(a.z, b.z, fc.z);
|
585 |
+
fd.w = fma(a.w, b.w, fc.w);
|
586 |
+
return fd;
|
587 |
+
}
|
588 |
+
|
589 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
590 |
+
|
591 |
+
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
|
592 |
+
{
|
593 |
+
uint32_t s = h0_h0(a);
|
594 |
+
Float8_ fd;
|
595 |
+
fd.x = fma(s, b.x, fc.x);
|
596 |
+
fd.y = fma(s, b.y, fc.y);
|
597 |
+
fd.z = fma(s, b.z, fc.z);
|
598 |
+
fd.w = fma(s, b.w, fc.w);
|
599 |
+
return fd;
|
600 |
+
}
|
601 |
+
|
602 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
603 |
+
#ifdef ENABLE_BF16
|
604 |
+
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
605 |
+
{
|
606 |
+
return bf16hfma2(a, b, c);
|
607 |
+
}
|
608 |
+
|
609 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
610 |
+
|
611 |
+
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
|
612 |
+
{
|
613 |
+
return bf16hfma2(bf162bf162(a), b, c);
|
614 |
+
}
|
615 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
616 |
+
|
617 |
+
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
|
618 |
+
{
|
619 |
+
bf16_4_t d;
|
620 |
+
d.x = fma(a.x, b.x, c.x);
|
621 |
+
d.y = fma(a.y, b.y, c.y);
|
622 |
+
return d;
|
623 |
+
}
|
624 |
+
|
625 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
626 |
+
|
627 |
+
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
|
628 |
+
{
|
629 |
+
__nv_bfloat162 s = bf162bf162(a);
|
630 |
+
bf16_4_t d;
|
631 |
+
d.x = fma(s, b.x, c.x);
|
632 |
+
d.y = fma(s, b.y, c.y);
|
633 |
+
return d;
|
634 |
+
}
|
635 |
+
|
636 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
637 |
+
|
638 |
+
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
|
639 |
+
{
|
640 |
+
bf16_8_t d;
|
641 |
+
d.x = fma(a.x, b.x, c.x);
|
642 |
+
d.y = fma(a.y, b.y, c.y);
|
643 |
+
d.z = fma(a.z, b.z, c.z);
|
644 |
+
d.w = fma(a.w, b.w, c.w);
|
645 |
+
return d;
|
646 |
+
}
|
647 |
+
|
648 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
649 |
+
|
650 |
+
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
|
651 |
+
{
|
652 |
+
__nv_bfloat162 s = bf162bf162(a);
|
653 |
+
bf16_8_t d;
|
654 |
+
d.x = fma(s, b.x, c.x);
|
655 |
+
d.y = fma(s, b.y, c.y);
|
656 |
+
d.z = fma(s, b.z, c.z);
|
657 |
+
d.w = fma(s, b.w, c.w);
|
658 |
+
return d;
|
659 |
+
}
|
660 |
+
|
661 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
662 |
+
|
663 |
+
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
|
664 |
+
{
|
665 |
+
return __bfloat162float(a) * __bfloat162float(b) + fc;
|
666 |
+
}
|
667 |
+
|
668 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
669 |
+
|
670 |
+
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
|
671 |
+
{
|
672 |
+
float2 fa = bf1622float2(a);
|
673 |
+
float2 fb = bf1622float2(b);
|
674 |
+
return fma(fa, fb, fc);
|
675 |
+
}
|
676 |
+
|
677 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
678 |
+
|
679 |
+
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
|
680 |
+
{
|
681 |
+
return fma(bf162bf162(a), b, fc);
|
682 |
+
}
|
683 |
+
|
684 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
685 |
+
|
686 |
+
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
|
687 |
+
{
|
688 |
+
Float4_ fd;
|
689 |
+
fd.x = fma(a.x, b.x, fc.x);
|
690 |
+
fd.y = fma(a.y, b.y, fc.y);
|
691 |
+
return fd;
|
692 |
+
}
|
693 |
+
|
694 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
695 |
+
|
696 |
+
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
|
697 |
+
{
|
698 |
+
__nv_bfloat162 s = bf162bf162(a);
|
699 |
+
Float4_ fd;
|
700 |
+
fd.x = fma(s, b.x, fc.x);
|
701 |
+
fd.y = fma(s, b.y, fc.y);
|
702 |
+
return fd;
|
703 |
+
}
|
704 |
+
|
705 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
706 |
+
|
707 |
+
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
|
708 |
+
{
|
709 |
+
Float8_ fd;
|
710 |
+
fd.x = fma(a.x, b.x, fc.x);
|
711 |
+
fd.y = fma(a.y, b.y, fc.y);
|
712 |
+
fd.z = fma(a.z, b.z, fc.z);
|
713 |
+
fd.w = fma(a.w, b.w, fc.w);
|
714 |
+
return fd;
|
715 |
+
}
|
716 |
+
|
717 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
718 |
+
|
719 |
+
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
|
720 |
+
{
|
721 |
+
__nv_bfloat162 s = bf162bf162(a);
|
722 |
+
Float8_ fd;
|
723 |
+
fd.x = fma(s, b.x, fc.x);
|
724 |
+
fd.y = fma(s, b.y, fc.y);
|
725 |
+
fd.z = fma(s, b.z, fc.z);
|
726 |
+
fd.w = fma(s, b.w, fc.w);
|
727 |
+
return fd;
|
728 |
+
}
|
729 |
+
#endif // ENABLE_BF16
|
730 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
731 |
+
|
732 |
+
template<typename Acc, typename A, typename B>
|
733 |
+
inline __device__ Acc mul(A a, B b)
|
734 |
+
{
|
735 |
+
return a * b;
|
736 |
+
}
|
737 |
+
|
738 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
739 |
+
|
740 |
+
template<>
|
741 |
+
inline __device__ float mul<float, float>(float a, float b)
|
742 |
+
{
|
743 |
+
return a * b;
|
744 |
+
}
|
745 |
+
|
746 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
747 |
+
|
748 |
+
template<>
|
749 |
+
inline __device__ float2 mul(float2 a, float2 b)
|
750 |
+
{
|
751 |
+
float2 c;
|
752 |
+
c.x = a.x * b.x;
|
753 |
+
c.y = a.y * b.y;
|
754 |
+
return c;
|
755 |
+
}
|
756 |
+
|
757 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
758 |
+
|
759 |
+
template<>
|
760 |
+
inline __device__ float2 mul(float a, float2 b)
|
761 |
+
{
|
762 |
+
float2 c;
|
763 |
+
c.x = a * b.x;
|
764 |
+
c.y = a * b.y;
|
765 |
+
return c;
|
766 |
+
}
|
767 |
+
|
768 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
769 |
+
|
770 |
+
template<>
|
771 |
+
inline __device__ float4 mul(float4 a, float4 b)
|
772 |
+
{
|
773 |
+
float4 c;
|
774 |
+
c.x = a.x * b.x;
|
775 |
+
c.y = a.y * b.y;
|
776 |
+
c.z = a.z * b.z;
|
777 |
+
c.w = a.w * b.w;
|
778 |
+
return c;
|
779 |
+
}
|
780 |
+
|
781 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
782 |
+
|
783 |
+
template<>
|
784 |
+
inline __device__ float4 mul(float a, float4 b)
|
785 |
+
{
|
786 |
+
float4 c;
|
787 |
+
c.x = a * b.x;
|
788 |
+
c.y = a * b.y;
|
789 |
+
c.z = a * b.z;
|
790 |
+
c.w = a * b.w;
|
791 |
+
return c;
|
792 |
+
}
|
793 |
+
|
794 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
795 |
+
|
796 |
+
template<>
|
797 |
+
inline __device__ Float8_ mul(float a, Float8_ b)
|
798 |
+
{
|
799 |
+
Float8_ c;
|
800 |
+
c.x = make_float2(a * b.x.x, a * b.x.y);
|
801 |
+
c.y = make_float2(a * b.y.x, a * b.y.y);
|
802 |
+
c.z = make_float2(a * b.z.x, a * b.z.y);
|
803 |
+
c.w = make_float2(a * b.w.x, a * b.w.y);
|
804 |
+
return c;
|
805 |
+
}
|
806 |
+
|
807 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
808 |
+
|
809 |
+
template<>
|
810 |
+
inline __device__ uint16_t mul(uint16_t a, uint16_t b)
|
811 |
+
{
|
812 |
+
uint16_t c;
|
813 |
+
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
814 |
+
return c;
|
815 |
+
}
|
816 |
+
|
817 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
818 |
+
|
819 |
+
template<>
|
820 |
+
inline __device__ uint32_t mul(uint32_t a, uint32_t b)
|
821 |
+
{
|
822 |
+
uint32_t c;
|
823 |
+
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
824 |
+
return c;
|
825 |
+
}
|
826 |
+
|
827 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
828 |
+
|
829 |
+
template<>
|
830 |
+
inline __device__ uint32_t mul(uint16_t a, uint32_t b)
|
831 |
+
{
|
832 |
+
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
|
833 |
+
}
|
834 |
+
|
835 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
836 |
+
|
837 |
+
template<>
|
838 |
+
inline __device__ uint2 mul(uint2 a, uint2 b)
|
839 |
+
{
|
840 |
+
uint2 c;
|
841 |
+
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
|
842 |
+
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
|
843 |
+
return c;
|
844 |
+
}
|
845 |
+
|
846 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
847 |
+
|
848 |
+
template<>
|
849 |
+
inline __device__ uint2 mul(uint16_t a, uint2 b)
|
850 |
+
{
|
851 |
+
uint32_t s = h0_h0(a);
|
852 |
+
uint2 c;
|
853 |
+
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
|
854 |
+
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
|
855 |
+
return c;
|
856 |
+
}
|
857 |
+
|
858 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
859 |
+
|
860 |
+
template<>
|
861 |
+
inline __device__ uint4 mul(uint4 a, uint4 b)
|
862 |
+
{
|
863 |
+
uint4 c;
|
864 |
+
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
|
865 |
+
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
|
866 |
+
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
|
867 |
+
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
|
868 |
+
return c;
|
869 |
+
}
|
870 |
+
|
871 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
872 |
+
|
873 |
+
template<>
|
874 |
+
inline __device__ uint4 mul(uint16_t a, uint4 b)
|
875 |
+
{
|
876 |
+
uint32_t s = h0_h0(a);
|
877 |
+
uint4 c;
|
878 |
+
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
|
879 |
+
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
|
880 |
+
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
|
881 |
+
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
|
882 |
+
return c;
|
883 |
+
}
|
884 |
+
|
885 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
886 |
+
|
887 |
+
template<>
|
888 |
+
inline __device__ float mul(uint16_t a, uint16_t b)
|
889 |
+
{
|
890 |
+
float fa = half_to_float(a);
|
891 |
+
float fb = half_to_float(b);
|
892 |
+
return fa * fb;
|
893 |
+
}
|
894 |
+
|
895 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
896 |
+
|
897 |
+
template<>
|
898 |
+
inline __device__ float mul(uint16_t a, float b)
|
899 |
+
{
|
900 |
+
return half_to_float(a) * b;
|
901 |
+
}
|
902 |
+
|
903 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
904 |
+
|
905 |
+
template<>
|
906 |
+
inline __device__ float2 mul(uint32_t a, uint32_t b)
|
907 |
+
{
|
908 |
+
float2 fa = half2_to_float2(a);
|
909 |
+
float2 fb = half2_to_float2(b);
|
910 |
+
return mul<float2, float2, float2>(fa, fb);
|
911 |
+
}
|
912 |
+
|
913 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
914 |
+
|
915 |
+
template<>
|
916 |
+
inline __device__ float2 mul(uint16_t a, uint32_t b)
|
917 |
+
{
|
918 |
+
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
|
919 |
+
}
|
920 |
+
|
921 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
922 |
+
|
923 |
+
template<>
|
924 |
+
inline __device__ Float4_ mul(uint2 a, uint2 b)
|
925 |
+
{
|
926 |
+
Float4_ fc;
|
927 |
+
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
|
928 |
+
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
|
929 |
+
return fc;
|
930 |
+
}
|
931 |
+
|
932 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
933 |
+
|
934 |
+
template<>
|
935 |
+
inline __device__ Float4_ mul(uint16_t a, uint2 b)
|
936 |
+
{
|
937 |
+
uint32_t s = h0_h0(a);
|
938 |
+
Float4_ fc;
|
939 |
+
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
|
940 |
+
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
|
941 |
+
return fc;
|
942 |
+
}
|
943 |
+
|
944 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
945 |
+
|
946 |
+
template<>
|
947 |
+
inline __device__ Float8_ mul(uint4 a, uint4 b)
|
948 |
+
{
|
949 |
+
Float8_ fc;
|
950 |
+
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
|
951 |
+
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
|
952 |
+
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
|
953 |
+
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
|
954 |
+
return fc;
|
955 |
+
}
|
956 |
+
|
957 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
958 |
+
|
959 |
+
template<>
|
960 |
+
inline __device__ Float8_ mul(uint16_t a, uint4 b)
|
961 |
+
{
|
962 |
+
uint32_t s = h0_h0(a);
|
963 |
+
Float8_ fc;
|
964 |
+
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
|
965 |
+
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
|
966 |
+
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
|
967 |
+
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
|
968 |
+
return fc;
|
969 |
+
}
|
970 |
+
|
971 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
972 |
+
|
973 |
+
#ifdef ENABLE_BF16
|
974 |
+
template<>
|
975 |
+
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
|
976 |
+
{
|
977 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
978 |
+
return __hmul(a, b);
|
979 |
+
#else
|
980 |
+
return bf16hmul(a, b);
|
981 |
+
#endif
|
982 |
+
}
|
983 |
+
|
984 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
985 |
+
|
986 |
+
template<>
|
987 |
+
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
|
988 |
+
{
|
989 |
+
return bf16hmul2(a, b);
|
990 |
+
}
|
991 |
+
|
992 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
993 |
+
|
994 |
+
template<>
|
995 |
+
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
|
996 |
+
{
|
997 |
+
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
|
998 |
+
}
|
999 |
+
|
1000 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1001 |
+
|
1002 |
+
template<>
|
1003 |
+
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
|
1004 |
+
{
|
1005 |
+
bf16_4_t c;
|
1006 |
+
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
1007 |
+
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
|
1008 |
+
return c;
|
1009 |
+
}
|
1010 |
+
|
1011 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1012 |
+
|
1013 |
+
template<>
|
1014 |
+
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
|
1015 |
+
{
|
1016 |
+
__nv_bfloat162 s = bf162bf162(a);
|
1017 |
+
bf16_4_t c;
|
1018 |
+
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
|
1019 |
+
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
|
1020 |
+
return c;
|
1021 |
+
}
|
1022 |
+
|
1023 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1024 |
+
|
1025 |
+
template<>
|
1026 |
+
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
|
1027 |
+
{
|
1028 |
+
bf16_8_t c;
|
1029 |
+
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
|
1030 |
+
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
|
1031 |
+
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
|
1032 |
+
    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    bf16_8_t c;
    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
    float fa = (float)a;
    float fb = (float)b;
    return fa * fb;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float mul(__nv_bfloat16 a, float b)
{
    return __bfloat162float(a) * b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
    float2 fa = bf1622float2(a);
    float2 fb = bf1622float2(b);
    return mul<float2, float2, float2>(fa, fb);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
    return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
{
    Float4_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    Float4_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
{
    Float8_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
{
    __nv_bfloat162 s = bf162bf162(a);
    Float8_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
    return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float v)
{
    return v;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float2 v)
{
    return v.x + v.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float4 v)
{
    return v.x + v.y + v.z + v.w;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v)
{
    float2 vf = bf1622float2(v);
    return vf.x + vf.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(bf16_4_t v)
{
    return sum(v.x) + sum(v.y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(bf16_8_t v)
{
    return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint16_t v)
{
    return half_to_float(v);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint32_t v)
{
    float2 tmp = half2_to_float2(v);
    return tmp.x + tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint2 v)
{
    uint32_t c = add(v.x, v.y);
    return sum(c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(uint4 v)
{
#if 1
    uint32_t c = add(v.x, v.y);
    c = add(c, v.z);
    c = add(c, v.w);
#else
    uint32_t c = add(v.x, v.y);
    uint32_t d = add(v.z, v.w);
    c = add(c, d);
#endif
    return sum(c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(Float4_ v)
{
    return v.x.x + v.x.y + v.y.x + v.y.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(Float8_ v)
{
    return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<T, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename A, typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<A, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void zero(uint16_t& dst)
{
    dst = uint16_t(0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ void zero(T& dst)
{
    constexpr int WORDS = sizeof(T) / 4;
    union {
        T raw;
        uint32_t words[WORDS];
    } tmp;
#pragma unroll
    for (int ii = 0; ii < WORDS; ++ii) {
        tmp.words[ii] = 0u;
    }
    dst = tmp.raw;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
{
    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
    return {cos(inv_freq), sin(inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
{
    float2 rot_v;
    rot_v.x = coef.x * v.x - coef.y * v.y;
    rot_v.y = coef.x * v.y + coef.y * v.x;
    return rot_v;
}

inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
{
    float2 fv = half2_to_float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return float2_to_half2(rot_fv);
}

#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
{
    float2 fv = bf1622float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif

inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    return;
}

inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    return;
}

inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }

    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q_.y = rotary_embedding_transform(q_.y, coef1);
}

inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }

    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
    Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    k_.x = rotary_embedding_transform(k_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q_.y = rotary_embedding_transform(q_.y, coef1);
    k_.y = rotary_embedding_transform(k_.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}

#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void
apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (2 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (4 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
    if (8 * tid >= rot_embed_dim) {
        return;
    }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16

template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
{
    return;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];

    vec = tmp.u32;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp_1, tmp_2;
    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);

    union {
        uint2 u32x2;
        uint16_t u16[4];
    } tmp_3;
    tmp_3.u16[0] = tmp_1.u16[0];
    tmp_3.u16[1] = tmp_2.u16[0];
    tmp_3.u16[2] = tmp_1.u16[1];
    tmp_3.u16[3] = tmp_2.u16[1];

    vec = tmp_3.u32x2;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        uint16_t u16[4];
    } tmp_1, tmp_2;
    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);

    union {
        uint4 u32x4;
        uint16_t u16[8];
    } tmp_3;
    tmp_3.u16[0] = tmp_1.u16[0];
    tmp_3.u16[1] = tmp_2.u16[0];
    tmp_3.u16[2] = tmp_1.u16[1];
    tmp_3.u16[3] = tmp_2.u16[1];
    tmp_3.u16[4] = tmp_1.u16[2];
    tmp_3.u16[5] = tmp_2.u16[2];
    tmp_3.u16[6] = tmp_1.u16[3];
    tmp_3.u16[7] = tmp_2.u16[3];

    vec = tmp_3.u32x4;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        __nv_bfloat16 bf16[2];
    } tmp_1, tmp_2;
    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);

    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
}

template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        __nv_bfloat16 bf16[4];
    } tmp_1, tmp_2;
    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);

    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
    vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
    vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif // ENABLE_BF16

template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.z = smem[transpose_idx + 1];
    vec.y = smem[smem_pitch + transpose_idx];
    vec.w = smem[smem_pitch + transpose_idx + 1];
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        half u16[2];
    } tmp;
    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];

    vec = tmp.u32;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}
#endif

template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}

template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
{
    return;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        uint16_t u16[4];
    } tmp_1, tmp_2;

    union {
        uint4 u32x4;
        uint16_t u16[8];
    } tmp_3;
    tmp_3.u32x4 = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];
    tmp_1.u16[2] = tmp_3.u16[4];
    tmp_2.u16[2] = tmp_3.u16[5];
    tmp_1.u16[3] = tmp_3.u16[6];
    tmp_2.u16[3] = tmp_3.u16[7];

    *reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
    *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp_1, tmp_2;

    union {
        uint2 u32x2;
        uint16_t u16[4];
    } tmp_3;
    tmp_3.u32x2 = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];

    *reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
    *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
    tmp.u32 = vec;

    smem[transpose_idx] = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx] = vec.x;
    smem[transpose_idx + 1] = vec.z;
    smem[smem_pitch + transpose_idx] = vec.y;
    smem[smem_pitch + transpose_idx + 1] = vec.w;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        half u16[2];
    } tmp;

    tmp.u32 = vec;
    smem[transpose_idx] = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx] = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}

template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
#endif

template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx] = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

} // namespace mmha
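A minimal host-side sketch of the rotation applied by rotary_embedding_coefficient and rotary_embedding_transform above, for readers who want to sanity-check the kernel math. The constants below are arbitrary example values, not taken from this diff.

    #include <cmath>
    #include <cstdio>

    // Host-side analogue of the per-pair rotation: channel pair `zid` at
    // timestep `t_step` is rotated by angle t_step / base^(zid / rot_embed_dim).
    int main()
    {
        const float base = 10000.0f;
        const int   rot_embed_dim = 128;
        const float t_step = 7.0f;      // example decoding timestep
        const int   zid = 4;            // even channel index of the pair

        float angle = t_step / powf(base, zid / (float)rot_embed_dim);
        float c = cosf(angle), s = sinf(angle);

        float q[2]   = {0.25f, -0.5f};  // one (even, odd) channel pair
        float rot[2] = {c * q[0] - s * q[1], c * q[1] + s * q[0]};
        printf("rotated pair: %f %f\n", rot[0], rot[1]);
        return 0;
    }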
AutoAWQ_kernels/awq_ext/attention/ft_attention.cpp
ADDED
@@ -0,0 +1,182 @@
// Adapted from NVIDIA/FasterTransformer and FlashAttention

#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>

#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
    if (TYPE == at::ScalarType::Half) { \
        using scalar_t = at::Half; \
        __VA_ARGS__(); \
    } else if (TYPE == at::ScalarType::BFloat16) { \
        using scalar_t = at::BFloat16; \
        __VA_ARGS__(); \
    } else if (TYPE == at::ScalarType::Float) { \
        using scalar_t = float; \
        __VA_ARGS__(); \
    } else { \
        AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
    }

template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
                               const cudaStream_t& stream);

template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};

template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
                const size_t batch_size,
                const size_t nheads,
                const size_t nheads_kv,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const float rotary_base,
                const bool neox_rotary_style,
                const int qkv_batch_stride,
                T *q_ptr,
                T *k_ptr,
                T *v_ptr,
                T *k_cache_ptr,
                T *v_cache_ptr,
                int *length_per_sample,
                float *alibi_slopes_ptr,
                T *out_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.linear_bias_slopes = alibi_slopes_ptr;
    params.out = out_ptr;
    params.cache_indir = nullptr;
    params.stride = qkv_batch_stride;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.num_kv_heads = nheads_kv;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.rotary_base = rotary_base;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
}

torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim,
                                     const float rotary_base,
                                     // neox_rotary_style = not interleaved
                                     const bool neox_rotary_style) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = q.size(1);
    int nheads_kv = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
    // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        CHECK_SHAPE(alibi_slopes, nheads);
        CHECK_CONTIGUOUS(alibi_slopes);
        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    torch::Tensor out = torch::empty_like(q);

    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim,
                   timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value()
                       ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   alibi_slopes_.has_value()
                       ? alibi_slopes_.value().data_ptr<float>(): nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}
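The checks in single_query_attention pin down the expected layouts: q, k and v are [B, H, Dh], v_cache is [B, H_kv, L, Dh], and k_cache uses the packed [B, H_kv, Dh/x, L, x] layout with x = 8 for fp16. A minimal calling sketch, assuming the extension above is built and linked; the shapes and timestep are arbitrary example values.

    #include <iostream>
    #include <torch/torch.h>
    #include "ft_attention.h"

    int main()
    {
        const int64_t B = 2, H = 8, H_kv = 8, L = 256, Dh = 128, x = 8;   // x = 8 for fp16
        auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);

        auto q       = torch::randn({B, H,    Dh}, opts);
        auto k       = torch::randn({B, H_kv, Dh}, opts);
        auto v       = torch::randn({B, H_kv, Dh}, opts);
        auto k_cache = torch::zeros({B, H_kv, Dh / x, L, x}, opts);   // packed K cache
        auto v_cache = torch::zeros({B, H_kv, L, Dh}, opts);

        int timestep = 10;   // tokens already stored in the cache (example value)
        auto out = single_query_attention(q, k, v, k_cache, v_cache,
                                          c10::nullopt, c10::nullopt,
                                          timestep, /*rotary_embedding_dim=*/0);
        std::cout << out.sizes() << std::endl;   // expected: [B, H, Dh]
        return 0;
    }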
AutoAWQ_kernels/awq_ext/attention/ft_attention.h
ADDED
@@ -0,0 +1,15 @@
#pragma once
#include <torch/extension.h>


torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const float rotary_base = 10000.0f,
                                     const bool neox_rotary_style=true);
AutoAWQ_kernels/awq_ext/exllama/cu_compat.cuh
ADDED
@@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cu
ADDED
@@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;

CudaBuffers::CudaBuffers
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
) :
    device(_device),
    temp_state_size(_temp_state_size),
    temp_state(_temp_state),
    temp_dq(_temp_dq)
{
    cudaSetDevice(_device);

    cudaStreamCreate(&alt_stream_1);
    cudaStreamCreate(&alt_stream_2);
    cudaStreamCreate(&alt_stream_3);
    cudaEventCreate(&alt_stream_1_done);
    cudaEventCreate(&alt_stream_2_done);
    cudaEventCreate(&alt_stream_3_done);
}

CudaBuffers::~CudaBuffers()
{
    cudaStreamDestroy(alt_stream_1);
    cudaStreamDestroy(alt_stream_2);
    cudaStreamDestroy(alt_stream_3);
    cudaEventDestroy(alt_stream_1_done);
    cudaEventDestroy(alt_stream_2_done);
    cudaEventDestroy(alt_stream_3_done);
}

CudaBuffers* get_buffers(const int device_index)
{
    return g_buffers[device_index];
}

void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
)
{
    CudaBuffers* buffers = new CudaBuffers
    (
        _device,
        _temp_state_size,
        _temp_state,
        _temp_dq
    );

    g_buffers[_device] = buffers;
}

void cleanup_buffers_cuda()
{
    for (int i = 0; i < CUDA_MAX_DEVICES; i++)
    {
        if (!g_buffers[i]) continue;
        delete g_buffers[i];
        g_buffers[i] = NULL;
    }
}
AutoAWQ_kernels/awq_ext/exllama/cuda_buffers.cuh
ADDED
@@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
    int device;

    half* temp_state;        // [max_hidden_rows * intermediate_size]
    int temp_state_size;
    half* temp_dq;           // size of largest quant tensor * 8

    cudaStream_t alt_stream_1;
    cudaStream_t alt_stream_2;
    cudaStream_t alt_stream_3;
    cudaEvent_t alt_stream_1_done;
    cudaEvent_t alt_stream_2_done;
    cudaEvent_t alt_stream_3_done;

    CudaBuffers
    (
        int _device,
        int _temp_state_size,
        half* _temp_state,
        half* _temp_dq
    );
    ~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
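A minimal setup sketch for the buffer API declared above (prepare_buffers_cuda, get_buffers, cleanup_buffers_cuda). The buffer sizes below are arbitrary example values, not sizes prescribed by this diff.

    #include <cuda_runtime.h>
    #include <cuda_fp16.h>
    #include "cuda_buffers.cuh"

    int main()
    {
        const int device = 0;
        const int temp_state_size = 2048 * 11008;          // rows * intermediate size (example)
        half *temp_state = nullptr, *temp_dq = nullptr;

        cudaSetDevice(device);
        cudaMalloc(&temp_state, (size_t)temp_state_size * sizeof(half));
        cudaMalloc(&temp_dq, (size_t)4096 * 11008 * sizeof(half));   // sized for the largest quant tensor (example)

        prepare_buffers_cuda(device, temp_state_size, temp_state, temp_dq);
        CudaBuffers* buffers = get_buffers(device);         // later consumed by the matmul path
        (void)buffers;

        cleanup_buffers_cuda();
        cudaFree(temp_state);
        cudaFree(temp_dq);
        return 0;
    }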
AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cu
ADDED
@@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include "column_remap.cuh"
#include "../util.cuh"

const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
    if (x_column >= x_width) return;
    //if (x_row >= x_height) return;

    int x_stride = x_width;
    int x_idx = x_row * x_stride + x_column;

    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
    int x_idx_end = x_row_end * x_stride + x_column;

    int s_column = x_map[x_column];
    int s_idx = x_row * x_stride + s_column;

    while (x_idx < x_idx_end)
    {
        x_new[x_idx] = x[s_idx];
        x_idx += x_stride;
        s_idx += x_stride;
    }
}

// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
)
{
    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

    dim3 blocks
    (
        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
        1
    );

    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
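The comment above states the invariant behind the remap: seq_x @ seq_w == x @ w when both operands are reordered with the same x_map. A tiny host-side sketch of that identity on plain floats (illustrative only; values are made up).

    #include <cstdio>

    int main()
    {
        const int K = 4;                      // inner dimension
        float x[K]        = {1, 2, 3, 4};     // one row of x
        float w[K]        = {10, 20, 30, 40}; // one column of w
        unsigned x_map[K] = {2, 0, 3, 1};     // new column j reads old column x_map[j]

        float ref = 0, remapped = 0;
        for (int k = 0; k < K; k++) ref += x[k] * w[k];

        // seq_x[j] = x[x_map[j]] and seq_w[j] = w[x_map[j]]: the same permutation
        // is applied to both operands, so the dot product is unchanged.
        for (int j = 0; j < K; j++) remapped += x[x_map[j]] * w[x_map[j]];

        printf("x @ w = %f, seq_x @ seq_w = %f\n", ref, remapped);   // identical
        return 0;
    }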
AutoAWQ_kernels/awq_ext/exllama/cuda_func/column_remap.cuh
ADDED
@@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cu
ADDED
@@ -0,0 +1,260 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
2 |
+
|
3 |
+
#include "q4_matmul.cuh"
|
4 |
+
#include "column_remap.cuh"
|
5 |
+
#include "../util.cuh"
|
6 |
+
#include "../matrix.cuh"
|
7 |
+
#include "../cu_compat.cuh"
|
8 |
+
#include "../cuda_buffers.cuh"
|
9 |
+
#if defined(USE_ROCM)
|
10 |
+
#include "../hip_compat.cuh"
|
11 |
+
#endif
|
12 |
+
|
13 |
+
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
14 |
+
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
15 |
+
|
16 |
+
typedef void (*fp_q4_matmul_kernel)
|
17 |
+
(
|
18 |
+
const half*,
|
19 |
+
const uint32_t*,
|
20 |
+
half*,
|
21 |
+
const half*,
|
22 |
+
const uint32_t*,
|
23 |
+
const int,
|
24 |
+
const int,
|
25 |
+
const int,
|
26 |
+
const int,
|
27 |
+
const int,
|
28 |
+
const uint32_t*,
|
29 |
+
bool
|
30 |
+
);
|
31 |
+
|
32 |
+
template<bool use_half2, bool use_groupsize, bool use_x_map>
|
33 |
+
__global__ void q4_matmul_kernel
|
34 |
+
(
|
35 |
+
const half* __restrict__ x,
|
36 |
+
const uint32_t* __restrict__ w,
|
37 |
+
half* __restrict__ out,
|
38 |
+
const half* __restrict__ w_scales,
|
39 |
+
const uint32_t* __restrict__ w_zeros,
|
40 |
+
const int height,
|
41 |
+
const int dim,
|
42 |
+
const int width,
|
43 |
+
const int groupsize,
|
44 |
+
const int block_size_z,
|
45 |
+
const uint32_t* __restrict__ x_map,
|
46 |
+
bool no_zero
|
47 |
+
)
|
48 |
+
{
|
49 |
+
// Start of block
|
50 |
+
|
51 |
+
int x_column = block_size_z * blockIdx.z;
|
52 |
+
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
53 |
+
|
54 |
+
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
55 |
+
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
56 |
+
|
57 |
+
int iterations = (x_column_end - x_column) / 8;
|
58 |
+
|
59 |
+
// Views
|
60 |
+
|
61 |
+
MatrixView_half x_(x, height, dim);
|
62 |
+
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
|
63 |
+
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
|
64 |
+
MatrixView_q4_column w_(w, dim, width);
|
65 |
+
MatrixView_half_rw out_(out, height, width);
|
66 |
+
|
67 |
+
// Zero output
|
68 |
+
|
69 |
+
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
70 |
+
{
|
71 |
+
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
72 |
+
__syncthreads();
|
73 |
+
}
|
74 |
+
|
75 |
+
// Loop over part of x row (and w column)
|
76 |
+
|
77 |
+
half2 acc = {};
|
78 |
+
half acc_h = {};
|
79 |
+
|
80 |
+
if constexpr (use_groupsize)
|
81 |
+
{
|
82 |
+
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
|
83 |
+
// could be slightly faster
|
84 |
+
|
85 |
+
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
86 |
+
{
|
87 |
+
if constexpr (use_half2)
|
88 |
+
{
|
89 |
+
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
90 |
+
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
|
91 |
+
|
92 |
+
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
93 |
+
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
94 |
+
}
|
95 |
+
else
|
96 |
+
{
|
97 |
+
half w_scale = w_scales_.item(group, w_column);
|
98 |
+
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
|
99 |
+
|
100 |
+
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
101 |
+
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
102 |
+
}
|
103 |
+
}
|
104 |
+
}
|
105 |
+
else
|
106 |
+
{
|
107 |
+
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
108 |
+
|
109 |
+
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
110 |
+
{
|
111 |
+
if constexpr (use_half2)
|
112 |
+
{
|
113 |
+
int group = k / groupsize;
|
114 |
+
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
115 |
+
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
|
116 |
+
|
117 |
+
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
118 |
+
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
119 |
+
}
|
120 |
+
else
|
121 |
+
{
|
122 |
+
int group = k / groupsize;
|
123 |
+
half w_scale = w_scales_.item(group, w_column);
|
124 |
+
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0f;
|
125 |
+
|
126 |
+
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
127 |
+
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
128 |
+
}
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
// Add to block result
|
133 |
+
|
134 |
+
if constexpr (use_half2)
|
135 |
+
{
|
136 |
+
half result = __hadd(__low2half(acc), __high2half(acc));
|
137 |
+
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
138 |
+
}
|
139 |
+
else
|
140 |
+
{
|
141 |
+
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
|
142 |
+
}
|
143 |
+
}
|
144 |
+
|
145 |
+
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
|
146 |
+
{
|
147 |
+
// <bool use_half2, bool use_groupsize, bool use_x_map>
|
148 |
+
if (tuningParams->matmul_no_half2) {
|
149 |
+
if (block_size_z % groupsize == 0) {
|
150 |
+
if (x_map) return q4_matmul_kernel<false, true, true >;
|
151 |
+
else return q4_matmul_kernel<false, true, false>;
|
152 |
+
} else {
|
            if (x_map) return q4_matmul_kernel<false, false, true >;
            else       return q4_matmul_kernel<false, false, false>;
        }
    } else {
        if (block_size_z % groupsize == 0)
        {
            if (x_map) return q4_matmul_kernel<true, true, true >;
            else       return q4_matmul_kernel<true, true, false>;
        } else {
            if (x_map) return q4_matmul_kernel<true, false, true >;
            else       return q4_matmul_kernel<true, false, false>;
        }
    }
};

// Compute y = x @ w

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);

    uint32_t* x_map = w->cuda_x_map;
    const half* x_mapped = x;
    if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
    {
        CudaBuffers* buffers = get_buffers(w->device);
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
        x_map = NULL;
    }

    int block_size_z;
    if      (w->width == 4096)  block_size_z = 384;  // 7B
    else if (w->width == 11008) block_size_z = 256;
    else if (w->width == 5120)  block_size_z = 384;  // 13B
    else if (w->width == 13824) block_size_z = 256;
    else if (w->width == 6656)  block_size_z = 256;  // 33B
    else if (w->width == 17920) block_size_z = 128;
    else                        block_size_z = 256;

    //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));

    dim3 threads(THREADS_X, THREADS_Y, 1);

    dim3 blocks
    (
        (width + threads.x - 1) / threads.x,
        (height + threads.y - 1) / threads.y,
        (dim + block_size_z - 1) / block_size_z
    );

    fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);

    kernel<<<blocks, threads, 0, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);
    CudaBuffers* buffers = get_buffers(w->device);

    const half* x_mapped = x;
    if (w->cuda_x_map)
    {
        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend for GPTQ with act-order. Please call the exllama_set_max_input_length function to increase the buffer size for a sequence length >=", x_height, ":\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, max_input_length=", x_height, ")");
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
    }

    w->reconstruct(buffers->temp_dq);

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
    const float alpha = 1.0f;
    const float beta = no_zero ? 1.0f : 0.0f;
    cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
                  x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
    const half alpha = __float2half(1.0f);
    const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}
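The launch geometry above tiles the output over x/y and splits the inner dimension into block_size_z slices, each slice contributing a partial sum to the same output tile. A minimal host-side sketch of that arithmetic follows; THREADS_X = 32 and THREADS_Y = 1 are assumptions for illustration (the real constants are defined earlier in q4_matmul.cu and are not shown here), and the example shapes are made up.

    // Standalone sketch (plain C++) of the ceil-division launch arithmetic in q4_matmul_cuda.
    #include <cstdio>

    static int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        const int THREADS_X = 32;   // assumed value, for illustration only
        const int THREADS_Y = 1;    // assumed value, for illustration only
        int width = 4096, height = 1, dim = 4096;   // single-token 7B-style projection (example)
        int block_size_z = 384;                     // value picked above for width == 4096

        int grid_x = ceil_div(width, THREADS_X);    // output columns per block
        int grid_y = ceil_div(height, THREADS_Y);   // output rows per block
        int grid_z = ceil_div(dim, block_size_z);   // slices of the inner dimension (partial sums)

        printf("grid = (%d, %d, %d)\n", grid_x, grid_y, grid_z);   // -> (128, 1, 11)
        return 0;
    }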
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matmul.cuh
ADDED
@@ -0,0 +1,43 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

#include "q4_matrix.cuh"
#include "../tuning.h"

// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif

void q4_matmul_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    const Q4Matrix* w,
    half* out,
    bool no_zero = false,
    cudaStream_t alt_stream = NULL
);

void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero = false
);

#endif
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cu
ADDED
@@ -0,0 +1,227 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
#include "../matrix.cuh"

using namespace std;

const int UNSHUF_BLOCKSIZE_X = 64;

const int RECONS_THREADS_X = 64;    // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1;     // Block size and thread count along rows in x and out, each thread converts 8 rows

vector<Q4Matrix*> g_q4_matrices;

void g_q4_keep_matrix(Q4Matrix* m)
{
    g_q4_matrices.push_back(m);
}

void g_q4_free_matrices()
{
    for (const auto& m : g_q4_matrices) delete m;
    g_q4_matrices.clear();
}

Q4Matrix::Q4Matrix
(
    const int _height,
    const int _width,
    const int _groups,

    uint32_t* _qweight,
    uint32_t* _qzeros,
    half* _scales,
    uint32_t* _g_idx,

    const int _device
) :
    height(_height),
    width(_width),
    groups(_groups),
    device(_device)
{
    cudaSetDevice(device);

    cuda_qweight = _qweight;
    cuda_qzeros = _qzeros;
    cuda_scales = _scales;

    groupsize = height / groups;

    if (_g_idx) make_sequential(_g_idx);
}

Q4Matrix::~Q4Matrix()
{
}

// Make sequential

__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const uint32_t* __restrict__ x_map,
    const int w_height,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;

    int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    if (w2_column >= w2_stride) return;

    int w_new2_row = blockIdx.y;

    int x_map_idx = w_new2_row << 3;

    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        int source_row = x_map[x_map_idx++];

        int w2_row = source_row >> 3;
        int w2_subrow = source_row & 0x07;
        int w2_row_shift = w2_subrow << 2;
        int wnew2_row_shift = i << 2;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x0000000f0000000f;
        src <<= wnew2_row_shift;
        dst |= src;
    }

    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}

void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
    uint32_t* cuda_new_qweight = NULL;
    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
    cudaMalloc(&cuda_x_map, height * sizeof(uint32_t));  // TODO: Should probably be allocated in PyTorch

    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));

    // Group histogram

    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;

    // Group map

    for (int i = 0, acc = 0; i < groups; i++)
    {
        short tmp = cpu_g_idx_map[i];
        cpu_g_idx_map[i] = acc;
        acc += tmp;
    }

    // X map (inverse)

    for (int row = 0; row < height; row++)
    {
        uint32_t target_group = cpu_g_idx[row];
        uint32_t target_row = cpu_g_idx_map[target_group];
        cpu_g_idx_map[target_group]++;
        cpu_x_map_inv[row] = target_row;
    }

    // X map

    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;

    // Move to CUDA

    cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);

    // Rearrange rows in w

    dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
    dim3 blocks
    (
        (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
        height / 8,
        1
    );

    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);

    // Replace qweights

    cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);

    // Cleanup

    cudaDeviceSynchronize();
    cudaFree(cuda_new_qweight);
    free(cpu_g_idx_map);
    free(cpu_x_map);
    free(cpu_x_map_inv);
}

__global__ void reconstruct_kernel
(
    const uint32_t* __restrict__ w,
    half* __restrict__ out,  // (y)
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int height,
    const int width,
    const int groupsize
)
{
    // Start of block

    int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
    int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
    if (column >= width) return;

    // Views

    MatrixView_q4_column w_(w, height, width);
    MatrixView_half_rw out_(out, height, width);
    MatrixView_half w_scales_(w_scales, height / groupsize, width);
    MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);

    // Groupsize version

    int group = row / groupsize;

    half w_scale = w_scales_.item(group, column);
    uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0f;

    uint32_t w_read = w_.item_uint32_t(row, column);
    half* out_ptr = out_.item_ptr(row, column);

    #pragma unroll
    for (int s = 0; s < 32; s += 4)
    {
        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
        *out_ptr = w_item; out_ptr += out_.width;
    }
}

void Q4Matrix::reconstruct(half* out)
{
    dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);

    dim3 blocks
    (
        (width + threads.x - 1) / threads.x,
        (height / 8 + threads.y - 1) / threads.y,
        1
    );

    reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}
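Q4Matrix::make_sequential builds the act-order row permutation on the host with a histogram, an exclusive prefix sum, and a stable scatter by group, then applies it on the GPU. The following CPU-only sketch mirrors those three loops on a toy g_idx (the data is invented for illustration) so the resulting x_map is easy to verify by hand.

    // Host-only sketch of the permutation construction in Q4Matrix::make_sequential.
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<unsigned> g_idx = {1, 0, 1, 0, 0, 1, 0, 1};   // group of each original row (toy data)
        int height = (int) g_idx.size();
        int groups = 2;

        std::vector<unsigned> group_start(groups, 0), x_map_inv(height), x_map(height);

        for (int i = 0; i < height; i++) group_start[g_idx[i]]++;   // group histogram
        for (int i = 0, acc = 0; i < groups; i++)                   // exclusive prefix sum
        {
            int t = group_start[i]; group_start[i] = acc; acc += t;
        }

        for (int row = 0; row < height; row++)                      // stable scatter by group
            x_map_inv[row] = group_start[g_idx[row]]++;

        for (int row = 0; row < height; row++)                      // invert: new row -> source row
            x_map[x_map_inv[row]] = row;

        for (int row = 0; row < height; row++) printf("%u ", x_map[row]);   // prints: 1 3 4 6 0 2 5 7
        printf("\n");
        return 0;
    }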
AutoAWQ_kernels/awq_ext/exllama/cuda_func/q4_matrix.cuh
ADDED
@@ -0,0 +1,53 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

class Q4Matrix
{
public:

    int device;

    int height;
    int width;
    int groups;
    int groupsize;

    uint32_t* cuda_qweight = NULL;
    uint32_t* cuda_qzeros = NULL;
    half* cuda_scales = NULL;
    uint32_t* cuda_x_map = NULL;

    Q4Matrix
    (
        const int _height,
        const int _width,
        const int _groups,

        uint32_t* _qweight,
        uint32_t* _qzeros,
        half* _scales,
        uint32_t* _g_idx,

        const int _device
    );

    ~Q4Matrix();

    void reconstruct(half* out);

private:

    void make_sequential(const uint32_t* cpu_g_idx);

};

void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();

#endif
AutoAWQ_kernels/awq_ext/exllama/exllama_ext.cpp
ADDED
@@ -0,0 +1,260 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"

#include <typeinfo>
#include <limits>
#include <algorithm>

// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.

void check_cuda(cudaError_t ret)
{
    switch (ret)
    {
        case cudaSuccess:
            break;

        case cudaUnspecified:
            printf(" **** Unspecified error\n");
            TORCH_CHECK(false, "CUDA error");
            break;

        default:
            printf(" **** CUDA error\n");
            printf(" **** %s\n", cudaGetErrorString(ret));
            TORCH_CHECK(false, "CUDA error");
            break;
    }
}

// Some decluttering macros

#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")

#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
    TORCH_CHECK(__index >= 0, "no device index"); \
    TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)

#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
    TORCH_CHECK_DTYPE(__w, kInt); \
    TORCH_CHECK_DTYPE(__w_scales, kHalf); \
    TORCH_CHECK_DTYPE(__w_zeros, kInt); \
    TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
    TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
    TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
    TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)

int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
    int groupsize = w.size(0) * 8 / w_zeros.size(0);
    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
    return groupsize;
}


// Tuning parameters

ExLlamaTuning tuningParams;

void set_tuning_params
(
    int matmul_recons_thd,
    bool matmul_fused_remap,
    bool matmul_no_half2
)
{
    tuningParams.matmul_recons_thd = matmul_recons_thd;
    tuningParams.matmul_fused_remap = matmul_fused_remap;
    tuningParams.matmul_no_half2 = matmul_no_half2;
}


// Release all unmanaged objects allocated by the extension

void cleanup()
{
    cleanup_buffers_cuda();
    g_q4_free_matrices();
}


// Prepare buffers for forward pass

void prepare_buffers
(
    torch::Device device,
    torch::Tensor temp_state,
    torch::Tensor temp_dq
)
{
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);
    const long max_int = std::numeric_limits<int>::max();

    prepare_buffers_cuda
    (
        device_index,
        // buffer size used for sanity checks
        std::clamp((long)temp_state.numel(), (long)0, max_int),
        (half*) temp_state.data_ptr(),
        (half*) temp_dq.data_ptr()
    );
}


// Create Q4Matrix, return handle

uintptr_t make_q4
(
    torch::Tensor qweight,
    torch::Tensor qzeros,
    torch::Tensor scales,
    torch::Tensor g_idx,
    int device
)
{
    TORCH_CHECK_DTYPE(qweight, kInt);
    TORCH_CHECK_DTYPE(qzeros, kInt);
    TORCH_CHECK_DTYPE(scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
    TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
    TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
    TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);

    int width = qweight.size(1);
    int height = qweight.size(0) * 8;
    int groups = qzeros.size(0);

    Q4Matrix* m = new Q4Matrix
    (
        height,
        width,
        groups,

        (uint32_t*) qweight.data_ptr(),
        (uint32_t*) qzeros.data_ptr(),
        (half*) scales.data_ptr(),
        g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),

        device
    );

    g_q4_keep_matrix(m);
    return reinterpret_cast<uintptr_t> (m);
}


// Matmul half @ quant -> half

void q4_matmul
(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out
)
{
    Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);

    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    int x_height = x.size(0);

    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
    {
        q4_matmul_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr()
        );
    }
    else
    {
        q4_matmul_recons_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            at::cuda::getCurrentCUDABlasHandle()
        );
    }
}


// Remap columns in half tensor

void column_remap
(
    torch::Tensor x,
    torch::Tensor x_new,
    torch::Tensor x_map
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(x_new, kHalf);
    TORCH_CHECK_DTYPE(x_map, kInt);
    TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);

    int height = x.size(0);
    int width = x.size(1);

    TORCH_CHECK_BUFFER_SIZE(x_new, height * width);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    column_remap_cuda
    (
        (half*) x.data_ptr(),
        (half*) x_new.data_ptr(),
        height,
        width,
        (uint32_t*) x_map.data_ptr()
    );
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
    m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
    m.def("cleanup", &cleanup, "cleanup");
    m.def("make_q4", &make_q4, "make_q4");
    m.def("q4_matmul", &q4_matmul, "q4_matmul");
    m.def("cleanup_buffers_cuda", &cleanup_buffers_cuda, "cleanup_buffers_cuda");
}
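make_q4 and get_groupsize rely on the 4-bit packing convention: eight quantized rows share one int32 row of qweight, and eight zero-points share one int32 column of qzeros. A small arithmetic sketch follows; the tensor shapes are hypothetical example values chosen to satisfy the shape checks above, not shapes taken from any specific model.

    // Sketch of the packing arithmetic behind make_q4 / get_groupsize (example shapes are made up).
    #include <cstdio>

    int main()
    {
        // Hypothetical GPTQ tensors for a 4096 x 4096 linear layer with group size 128:
        long qweight_rows = 512,  qweight_cols = 4096;   // int32, 4096 / 8 = 512 packed rows
        long qzeros_rows  = 32,   qzeros_cols  = 512;    // int32, 4096 / 8 = 512 packed columns
        long scales_rows  = 32,   scales_cols  = 4096;   // fp16

        long height    = qweight_rows * 8;    // unpacked input features = 4096
        long width     = qweight_cols;        // output features         = 4096
        long groups    = qzeros_rows;         // quantization groups     = 32
        long groupsize = height / groups;     // rows per group          = 128

        // Mirrors TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8) and friends above
        bool ok = (qweight_cols == qzeros_cols * 8)
               && (scales_cols == qweight_cols)
               && (qzeros_rows == scales_rows);

        printf("height=%ld width=%ld groups=%ld groupsize=%ld checks=%s\n",
               height, width, groups, groupsize, ok ? "pass" : "fail");
        return 0;
    }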
AutoAWQ_kernels/awq_ext/exllama/hip_compat.cuh
ADDED
@@ -0,0 +1,51 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Workaround for a bug in hipamd, backported from upstream; this is fixed in ROCm 5.6.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    return __half_raw{
        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

// ROCm 6.0 compatible, from: /opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_fp16.h:1708
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{_Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm

#endif
AutoAWQ_kernels/awq_ext/exllama/matrix.cuh
ADDED
@@ -0,0 +1,294 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _matrix_cuh
#define _matrix_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }
};

class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale

__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,         // + 1 (!!)
    const int count
)
{
    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        // half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff];  // (constant memory is too slow apparently)
        // half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];
        // half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
        // half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];

        half2 tmp = __hmul2(*h_ptr++, v_01);
        tmp = __hfma2(*h_ptr++, v_23, tmp);
        tmp = __hfma2(*h_ptr++, v_45, tmp);
        tmp = __hfma2(*h_ptr++, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,         // + 1 (!!)
    const int count
)
{
    const half* h_ptr = h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(*h_ptr++, v_0);
        tmp = __hfma(*h_ptr++, v_1, tmp);
        tmp = __hfma(*h_ptr++, v_2, tmp);
        tmp = __hfma(*h_ptr++, v_3, tmp);
        tmp = __hfma(*h_ptr++, v_4, tmp);
        tmp = __hfma(*h_ptr++, v_5, tmp);
        tmp = __hfma(*h_ptr++, v_6, tmp);
        tmp = __hfma(*h_ptr++, v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map

__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,         // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        half h_0 = h_ptr[*x_map_ptr++];
        half h_1 = h_ptr[*x_map_ptr++];
        half h_2 = h_ptr[*x_map_ptr++];
        half h_3 = h_ptr[*x_map_ptr++];
        half h_4 = h_ptr[*x_map_ptr++];
        half h_5 = h_ptr[*x_map_ptr++];
        half h_6 = h_ptr[*x_map_ptr++];
        half h_7 = h_ptr[*x_map_ptr++];

        half2 h_01 = __halves2half2(h_0, h_1);
        half2 h_23 = __halves2half2(h_2, h_3);
        half2 h_45 = __halves2half2(h_4, h_5);
        half2 h_67 = __halves2half2(h_6, h_7);

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,            // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,               // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,         // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

#endif
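The dot-product helpers above all share the same inner step: unpack eight 4-bit weights from one uint32, subtract the (pre-incremented) zero-point, and scale. A CPU sketch of that step follows; the packed word, zero-point, and scale are made-up values, and the real kernels do this in half precision with __hfma2 rather than float.

    // CPU sketch of the 4-bit unpack-and-dequantize step used inside dot_product_8.
    #include <cstdio>
    #include <cstdint>

    int main()
    {
        uint32_t v_read = 0x76543210u;   // eight packed 4-bit weights, low nibble first (example)
        int v_zero = 8;                  // zero-point + 1, matching the "+ 1 (!!)" convention above
        float v_scale = 0.05f;           // example group scale

        for (int s = 0; s < 32; s += 4)
        {
            int q = (int)((v_read >> s) & 0x0f);   // extract one nibble
            float w = (q - v_zero) * v_scale;      // dequantize: (q - zero) * scale
            printf("%+.2f ", w);                   // -0.40 -0.35 ... -0.05
        }
        printf("\n");
        return 0;
    }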
AutoAWQ_kernels/awq_ext/exllama/tuning.h
ADDED
@@ -0,0 +1,13 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _tuning_h
#define _tuning_h

struct ExLlamaTuning
{
    int matmul_recons_thd;
    bool matmul_fused_remap;
    bool matmul_no_half2;
};

#endif
AutoAWQ_kernels/awq_ext/exllama/util.cuh
ADDED
@@ -0,0 +1,33 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _util_cuh
#define _util_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif

// React to failure on return code != cudaSuccess

#define _cuda_check(fn) \
do { \
    {_cuda_err = fn;} \
    if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)

// React to failure on return code == 0

#define _alloc_check(fn) \
do { \
    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
    else _cuda_err = cudaSuccess; \
} while(false)

#endif
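The two macros above assume a local cudaError_t named _cuda_err and a _cuda_fail label at the call site. The sketch below shows one way such a call site could look; example_alloc is a hypothetical helper written for illustration, not a function from this extension.

    // Illustrative use of the goto-style _cuda_check macro (hypothetical helper, assumes util.cuh).
    #include <cuda_runtime.h>
    #include "util.cuh"

    cudaError_t example_alloc(void** ptr, size_t bytes)
    {
        cudaError_t _cuda_err = cudaSuccess;
        _cuda_check(cudaMalloc(ptr, bytes));         // jumps to _cuda_fail on any CUDA error
        _cuda_check(cudaMemset(*ptr, 0, bytes));
        return cudaSuccess;

    _cuda_fail:
        return _cuda_err;                            // caller can report via cudaGetErrorString
    }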
AutoAWQ_kernels/awq_ext/exllamav2/config.h
ADDED
@@ -0,0 +1,13 @@
#ifndef _config_h
#define _config_h

#define MAX_Q_GEMM_ROWS 50

#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cpp/util.h
ADDED
@@ -0,0 +1,12 @@
#ifndef _util_h
#define _util_h

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat.cuh
ADDED
@@ -0,0 +1,56 @@
#ifndef _compat_cuh
#define _compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/compat_gemm.cuh
ADDED
@@ -0,0 +1,38 @@
#ifndef _compat_gemm_cuh
#define _compat_gemm_cuh

#if defined(USE_ROCM)

// For some reason this include is not present anywhere in the exllama_v2 codebase, but it is required
// for symbols such as hipblasHalf.
#include <hipblas/hipblas.h>

__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/matrix_view.cuh
ADDED
@@ -0,0 +1,121 @@
#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "quant/qdq_util.cuh"

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }

    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*) item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __low2half(i01);
        items[1] = __high2half(i01);
        items[2] = __low2half(i23);
        items[3] = __high2half(i23);
    }
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2float(__low2half(i01));
        items[1] = __half2float(__high2half(i01));
        items[2] = __half2float(__low2half(i23));
        items[3] = __half2float(__high2half(i23));
    }

    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2half2(__low2half(i01));
        items[1] = __half2half2(__high2half(i01));
        items[2] = __half2half2(__low2half(i23));
        items[3] = __half2half2(__high2half(i23));
    }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }

    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        half2 v01 = __halves2half2(v0, v1);
        half2 v23 = __halves2half2(v2, v3);
        half2* ptr = (half2*) item_ptr(row, column);
        ptr[0] = v01;
        ptr[1] = v23;
    }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }

    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
    }

    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
        items[2] = (d >> 8) & 0x0f;
        items[3] = (d >> 12) & 0x0f;
    }
};

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cu
ADDED
@@ -0,0 +1,211 @@
#include "q_gemm.cuh"
#include "util.cuh"
#include "matrix_view.cuh"
#include "../config.h"

#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"

#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256

#include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh"

#include "compat_gemm.cuh"

void gemm_half_q_half_cuda_part
(
    const half* a,
    QMatrix* b,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    int m_count,
    bool clear
)
{
    if (!b->is_gptq)
    {
        dim3 blockDim, gridDim;
        blockDim.x = BLOCK_KN_SIZE;
        blockDim.y = 1;
        blockDim.z = 1;
        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
        gridDim.y = DIVIDE(size_m, m_count);
        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);

        kernel<<<gridDim, blockDim>>>
        (
            a,
            b->cuda_q_weight,
            b->cuda_q_scale,
            b->cuda_q_scale_max,
            c,
            size_m,
            size_n,
            size_k,
            b->groups,
            b->groupsize,
            b->cuda_q_perm,
            b->rows_8,
            b->rows_6,
            b->rows_5,
            b->rows_4,
            b->rows_3,
            b->rows_2,
            clear
        );
    }
    else
    {
        dim3 blockDim, gridDim;
        blockDim.x = BLOCK_KN_SIZE;
        blockDim.y = 1;
        blockDim.z = 1;
        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
        gridDim.y = DIVIDE(size_m, m_count);
        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);

        // DBGX((uint64_t) b->cuda_q_perm);
        // DBGI(b->rows_4);
        // DBGI(b->height);

        kernel<<<gridDim, blockDim>>>
        (
            a,
            b->cuda_q_weight,
            b->cuda_gptq_qzeros,
            b->cuda_gptq_scales,
            c,
            size_m,
            size_n,
            size_k,
            b->groups,
            b->groupsize,
            b->cuda_q_perm,
            b->rows_4,
            clear
        );
    }
}

void gemm_half_q_half_cuda
(
    cublasHandle_t cublas_handle,
    const half* a,
    QMatrix* b,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    bool clear,
    half* temp_dq,
    bool force_cuda
)
{
    if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
    {
        //printf("cublas\n");

        // Reconstruct FP16 matrix, then cuBLAS

        if (!temp_dq) temp_dq = b->temp_dq;
        b->reconstruct(temp_dq);

        //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);

        const half alpha = __float2half(1.0f);
        const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
        cublasHgemm(cublas_handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    size_n, size_m, size_k,
                    &alpha, temp_dq, size_n,
                    a, size_k,
                    &beta, c, size_n);

        //const float alpha = 1.0f;
        //const float beta = clear ? 0.0f : 1.0f;
        //cublasSgemmEx(cublas_handle,
        //              CUBLAS_OP_N,
        //              CUBLAS_OP_N,
        //              size_n, size_m, size_k,
        //              &alpha, temp_dq, CUDA_R_16F, size_n,
        //              a, CUDA_R_16F, size_k,
        //              &beta, c, CUDA_R_16F, size_n);

        //const float alpha = 1.0f;
        //const float beta = clear ? 0.0f : 1.0f;
        //cublasGemmEx(cublas_handle,
        //             CUBLAS_OP_N, CUBLAS_OP_N,
        //             size_n, size_m, size_k,
        //             &alpha, temp_dq, CUDA_R_16F, size_n,
        //             a, CUDA_R_16F, size_k,
        //             &beta, c, CUDA_R_16F, size_n,
        //             CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
    }
    else
    {
        //printf("cuda\n");

        // Quantized matmul

        //if (clear) clear_tensor_cuda(c, size_m, size_n);

        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
        int last_chunk_size = size_m - last_chunk;

        if (max_chunks)
        {
            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
        }

        if (last_chunk_size)
        {
            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
        }
    }
}

__global__ void clear_kernel
(
    half* __restrict__ c,
    const int size_m,
    const int size_n
)
{
    int m = blockIdx.y;
    int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
    if (n >= size_n) return;
    int4* c_ptr = (int4*)(c + m * size_n + n);
    *c_ptr = {};
}

void clear_tensor_cuda
(
    half* c,
    int size_m,
    int size_n
)
{
    return;
    dim3 blockDim, gridDim;
    blockDim.x = CLEAR_N_SIZE;
    blockDim.y = 1;
    gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
    gridDim.y = size_m;
    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
}
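The quantized path of gemm_half_q_half_cuda splits the rows of a into full BLOCK_M_SIZE_MAX chunks plus one remainder launch. A minimal sketch of that chunking arithmetic follows, using the BLOCK_M_SIZE_MAX value defined above; the example row count is made up.

    // Sketch of the row-chunking arithmetic in gemm_half_q_half_cuda's quantized path.
    #include <cstdio>

    int main()
    {
        const int BLOCK_M_SIZE_MAX = 8;   // matches the #define above
        int size_m = 19;                  // e.g. 19 input rows (example)

        int max_chunks = size_m / BLOCK_M_SIZE_MAX;        // 2 full 8-row tiles
        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;    // 16 rows handled by the first call
        int last_chunk_size = size_m - last_chunk;         // 3 remaining rows in a second call

        printf("full-rows=%d remainder=%d\n", last_chunk, last_chunk_size);
        return 0;
    }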
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm.cuh
ADDED
@@ -0,0 +1,33 @@
1 |
+
#ifndef _q_gemm_cuh
|
2 |
+
#define _q_gemm_cuh
|
3 |
+
|
4 |
+
#include <cuda_runtime.h>
|
5 |
+
#include <cuda_fp16.h>
|
6 |
+
#include <cstdint>
|
7 |
+
#include <cstdio>
|
8 |
+
#include <ATen/cuda/CUDAContext.h>
|
9 |
+
|
10 |
+
#include "q_matrix.cuh"
|
11 |
+
|
12 |
+
void gemm_half_q_half_cuda
|
13 |
+
(
|
14 |
+
cublasHandle_t cublas_handle,
|
15 |
+
const half* a,
|
16 |
+
QMatrix* b,
|
17 |
+
half* c,
|
18 |
+
int size_m,
|
19 |
+
int size_n,
|
20 |
+
int size_k,
|
21 |
+
bool clear = false,
|
22 |
+
half* reconstruct = NULL,
|
23 |
+
bool force_cuda = false
|
24 |
+
);
|
25 |
+
|
26 |
+
void clear_tensor_cuda
|
27 |
+
(
|
28 |
+
half* c,
|
29 |
+
int size_m,
|
30 |
+
int size_n
|
31 |
+
);
|
32 |
+
|
33 |
+
#endif
|
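
A minimal host-side usage sketch for the interface declared above (editor's addition, not part of the commit). It assumes the caller has already built a QMatrix* and allocated device buffers of m*k halves for the input, m*n for the output, and k*n for the scratch reconstruction buffer; at::cuda::getCurrentCUDABlasHandle() comes from the ATen header this file includes.

// Editor's sketch (not part of the commit); buffer names are illustrative.
#include <ATen/cuda/CUDAContext.h>
#include "q_gemm.cuh"

void run_quant_gemm(const half* d_a, QMatrix* b, half* d_c,
                    half* d_temp_dq, int m, int n, int k)
{
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    // clear = true overwrites c; the temp buffer is only used when the call
    // takes the reconstruct-then-cuBLAS path for large m.
    gemm_half_q_half_cuda(handle, d_a, b, d_c, m, n, k,
                          /*clear=*/true, /*reconstruct=*/d_temp_dq,
                          /*force_cuda=*/false);
}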
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel.cuh
ADDED
@@ -0,0 +1,487 @@
1 |
+
#include "compat.cuh"
|
2 |
+
|
3 |
+
#include <cuda_runtime.h>
|
4 |
+
#include <cuda_fp16.h>
|
5 |
+
|
6 |
+
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
7 |
+
{
|
8 |
+
half2 result = {};
|
9 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
10 |
+
#pragma unroll
|
11 |
+
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
12 |
+
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
13 |
+
}
|
14 |
+
|
15 |
+
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
|
16 |
+
{
|
17 |
+
half2 result = {};
|
18 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
19 |
+
#pragma unroll
|
20 |
+
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
21 |
+
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
22 |
+
}
|
23 |
+
|
24 |
+
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
|
25 |
+
{
|
26 |
+
half2 result = {};
|
27 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
28 |
+
#pragma unroll
|
29 |
+
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
30 |
+
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
31 |
+
}
|
32 |
+
|
33 |
+
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
|
34 |
+
{
|
35 |
+
half2 result = {};
|
36 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
37 |
+
#pragma unroll
|
38 |
+
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
39 |
+
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
40 |
+
return fma(result_f, qs_f, g_result);
|
41 |
+
}
|
42 |
+
|
43 |
+
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
|
44 |
+
{
|
45 |
+
half2 result = {};
|
46 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
47 |
+
#pragma unroll
|
48 |
+
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
49 |
+
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
50 |
+
return fma(result_f, qs_f, g_result);
|
51 |
+
}
|
52 |
+
|
53 |
+
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
|
54 |
+
{
|
55 |
+
half2 result = {};
|
56 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
57 |
+
#pragma unroll
|
58 |
+
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
59 |
+
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
60 |
+
return fma(result_f, qs_f, g_result);
|
61 |
+
}
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
typedef void (*fp_gemm_half_q_half_kernel)
|
66 |
+
(
|
67 |
+
const half*,
|
68 |
+
const uint32_t*,
|
69 |
+
const uint32_t*,
|
70 |
+
const half*,
|
71 |
+
half*,
|
72 |
+
const int,
|
73 |
+
const int,
|
74 |
+
const int,
|
75 |
+
const int,
|
76 |
+
const int,
|
77 |
+
const uint16_t*,
|
78 |
+
const int,
|
79 |
+
const int,
|
80 |
+
const int,
|
81 |
+
const int,
|
82 |
+
const int,
|
83 |
+
const int,
|
84 |
+
const bool
|
85 |
+
);
|
86 |
+
|
87 |
+
template <bool first_block, int m_count>
|
88 |
+
__global__ void gemm_half_q_half_kernel
|
89 |
+
(
|
90 |
+
const half* __restrict__ a,
|
91 |
+
const uint32_t* __restrict__ b_q_weight,
|
92 |
+
const uint32_t* __restrict__ b_q_scale,
|
93 |
+
const half* __restrict__ b_q_scale_max,
|
94 |
+
half* __restrict__ c,
|
95 |
+
const int size_m,
|
96 |
+
const int size_n,
|
97 |
+
const int size_k,
|
98 |
+
const int groups,
|
99 |
+
const int groupsize,
|
100 |
+
const uint16_t* __restrict__ b_q_perm,
|
101 |
+
const int rows_8,
|
102 |
+
const int rows_6,
|
103 |
+
const int rows_5,
|
104 |
+
const int rows_4,
|
105 |
+
const int rows_3,
|
106 |
+
const int rows_2,
|
107 |
+
const bool clear
|
108 |
+
)
|
109 |
+
{
|
110 |
+
MatrixView_half a_(a, size_m, size_k);
|
111 |
+
MatrixView_half_rw c_(c, size_m, size_n);
|
112 |
+
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
113 |
+
|
114 |
+
int t = threadIdx.x;
|
115 |
+
|
116 |
+
// Block
|
117 |
+
|
118 |
+
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
119 |
+
int offset_m = blockIdx.y * m_count;
|
120 |
+
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
121 |
+
|
122 |
+
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
123 |
+
int end_m = min(offset_m + m_count, size_m);
|
124 |
+
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
125 |
+
int n = offset_n + t * 4;
|
126 |
+
|
127 |
+
// Preload block_a
|
128 |
+
|
129 |
+
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
130 |
+
|
131 |
+
if (offset_k + t < end_k)
|
132 |
+
{
|
133 |
+
for (int m = 0; m < m_count; ++m)
|
134 |
+
{
|
135 |
+
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
136 |
+
half* block_a_ptr = block_a[m];
|
137 |
+
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
138 |
+
block_a_ptr[t] = a0;
|
139 |
+
}
|
140 |
+
}
|
141 |
+
|
142 |
+
// Clear
|
143 |
+
|
144 |
+
if (n >= size_n) return;
|
145 |
+
|
146 |
+
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
147 |
+
{
|
148 |
+
for (int m = 0; m < m_count; m++)
|
149 |
+
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
|
150 |
+
}
|
151 |
+
|
152 |
+
__syncthreads();
|
153 |
+
|
154 |
+
// Find initial group
|
155 |
+
|
156 |
+
int group = offset_k / groupsize;
|
157 |
+
|
158 |
+
// Preload scales
|
159 |
+
|
160 |
+
float scales[MAX_GROUPS_IN_BLOCK][4];
|
161 |
+
|
162 |
+
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
163 |
+
for (int g = 0; g < groups_in_block; g++)
|
164 |
+
{
|
165 |
+
int qscales[4];
|
166 |
+
b_q_scale_.item4(qscales, group + g, n);
|
167 |
+
qscales[0]++;
|
168 |
+
qscales[1]++;
|
169 |
+
qscales[2]++;
|
170 |
+
qscales[3]++;
|
171 |
+
float maxscale = __half2float(b_q_scale_max[group + g]);
|
172 |
+
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
173 |
+
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
174 |
+
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
175 |
+
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
176 |
+
}
|
177 |
+
|
178 |
+
// a, b offset
|
179 |
+
|
180 |
+
int pre_rows_8 = min(rows_8, offset_k);
|
181 |
+
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
182 |
+
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
183 |
+
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
184 |
+
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
185 |
+
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
186 |
+
int qk = 0;
|
187 |
+
qk += pre_rows_8 / 32 * 8;
|
188 |
+
qk += pre_rows_6 / 32 * 6;
|
189 |
+
qk += pre_rows_5 / 32 * 5;
|
190 |
+
qk += pre_rows_4 / 32 * 4;
|
191 |
+
qk += pre_rows_3 / 32 * 3;
|
192 |
+
qk += pre_rows_2 / 32 * 2;
|
193 |
+
|
194 |
+
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
195 |
+
const half* a_ptr = &block_a[0][0];
|
196 |
+
int a_stride = BLOCK_KN_SIZE;
|
197 |
+
|
198 |
+
// Initial group
|
199 |
+
|
200 |
+
int scales_idx = 0;
|
201 |
+
float qs_f0 = scales[scales_idx][0];
|
202 |
+
float qs_f1 = scales[scales_idx][1];
|
203 |
+
float qs_f2 = scales[scales_idx][2];
|
204 |
+
float qs_f3 = scales[scales_idx][3];
|
205 |
+
int nextgroup = offset_k + groupsize;
|
206 |
+
|
207 |
+
// Column result
|
208 |
+
|
209 |
+
float block_c[m_count][4] = {};
|
210 |
+
|
211 |
+
// Dequantize groups
|
212 |
+
|
213 |
+
int k = offset_k;
|
214 |
+
|
215 |
+
while (k < rows_8 && k < end_k)
|
216 |
+
{
|
217 |
+
if (k == nextgroup)
|
218 |
+
{
|
219 |
+
group++;
|
220 |
+
scales_idx++;
|
221 |
+
qs_f0 = scales[scales_idx][0];
|
222 |
+
qs_f1 = scales[scales_idx][1];
|
223 |
+
qs_f2 = scales[scales_idx][2];
|
224 |
+
qs_f3 = scales[scales_idx][3];
|
225 |
+
nextgroup += groupsize;
|
226 |
+
}
|
227 |
+
|
228 |
+
#pragma unroll
|
229 |
+
for (int j = 0; j < 4; j++)
|
230 |
+
{
|
231 |
+
int4 load_int4[2];
|
232 |
+
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
233 |
+
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
234 |
+
|
235 |
+
half2 dq[4][4];
|
236 |
+
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
|
237 |
+
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
|
238 |
+
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
|
239 |
+
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
|
240 |
+
|
241 |
+
for (int m = 0; m < m_count; m++)
|
242 |
+
{
|
243 |
+
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
244 |
+
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
245 |
+
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
246 |
+
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
247 |
+
}
|
248 |
+
a_ptr += 8;
|
249 |
+
}
|
250 |
+
k += 32;
|
251 |
+
}
|
252 |
+
|
253 |
+
while (k < rows_6 && k < end_k)
|
254 |
+
{
|
255 |
+
if (k == nextgroup)
|
256 |
+
{
|
257 |
+
group++;
|
258 |
+
scales_idx++;
|
259 |
+
qs_f0 = scales[scales_idx][0];
|
260 |
+
qs_f1 = scales[scales_idx][1];
|
261 |
+
qs_f2 = scales[scales_idx][2];
|
262 |
+
qs_f3 = scales[scales_idx][3];
|
263 |
+
nextgroup += groupsize;
|
264 |
+
}
|
265 |
+
|
266 |
+
#pragma unroll
|
267 |
+
for (int j = 0; j < 2; j++)
|
268 |
+
{
|
269 |
+
int4 load_int4[3];
|
270 |
+
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
271 |
+
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
272 |
+
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
273 |
+
|
274 |
+
half2 dq[4][8];
|
275 |
+
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
276 |
+
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
277 |
+
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
278 |
+
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
279 |
+
|
280 |
+
for (int m = 0; m < m_count; m++)
|
281 |
+
{
|
282 |
+
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
283 |
+
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
284 |
+
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
285 |
+
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
286 |
+
}
|
287 |
+
a_ptr += 16;
|
288 |
+
}
|
289 |
+
k += 32;
|
290 |
+
}
|
291 |
+
|
292 |
+
while (k < rows_5 && k < end_k)
|
293 |
+
{
|
294 |
+
if (k == nextgroup)
|
295 |
+
{
|
296 |
+
group++;
|
297 |
+
scales_idx++;
|
298 |
+
qs_f0 = scales[scales_idx][0];
|
299 |
+
qs_f1 = scales[scales_idx][1];
|
300 |
+
qs_f2 = scales[scales_idx][2];
|
301 |
+
qs_f3 = scales[scales_idx][3];
|
302 |
+
nextgroup += groupsize;
|
303 |
+
}
|
304 |
+
|
305 |
+
#pragma unroll
|
306 |
+
for (int j = 0; j < 1; j++)
|
307 |
+
{
|
308 |
+
int4 load_int4[5];
|
309 |
+
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
310 |
+
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
311 |
+
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
312 |
+
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
|
313 |
+
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
|
314 |
+
|
315 |
+
half2 dq[4][16];
|
316 |
+
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
|
317 |
+
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
|
318 |
+
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
|
319 |
+
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
|
320 |
+
|
321 |
+
for (int m = 0; m < m_count; m++)
|
322 |
+
{
|
323 |
+
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
324 |
+
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
325 |
+
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
326 |
+
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
327 |
+
}
|
328 |
+
a_ptr += 32;
|
329 |
+
}
|
330 |
+
|
331 |
+
k += 32;
|
332 |
+
}
|
333 |
+
|
334 |
+
while (k < rows_4 && k < end_k)
|
335 |
+
{
|
336 |
+
if (k == nextgroup)
|
337 |
+
{
|
338 |
+
group++;
|
339 |
+
scales_idx++;
|
340 |
+
qs_f0 = scales[scales_idx][0];
|
341 |
+
qs_f1 = scales[scales_idx][1];
|
342 |
+
qs_f2 = scales[scales_idx][2];
|
343 |
+
qs_f3 = scales[scales_idx][3];
|
344 |
+
nextgroup += groupsize;
|
345 |
+
}
|
346 |
+
|
347 |
+
#pragma unroll
|
348 |
+
for (int j = 0; j < 4; j++)
|
349 |
+
{
|
350 |
+
int4 load_int4[1];
|
351 |
+
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
352 |
+
|
353 |
+
half2 dq[4][4];
|
354 |
+
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
|
355 |
+
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
|
356 |
+
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
|
357 |
+
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
|
358 |
+
|
359 |
+
for (int m = 0; m < m_count; m++)
|
360 |
+
{
|
361 |
+
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
362 |
+
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
363 |
+
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
364 |
+
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
365 |
+
}
|
366 |
+
a_ptr += 8;
|
367 |
+
}
|
368 |
+
k += 32;
|
369 |
+
}
|
370 |
+
|
371 |
+
while (k < rows_3 && k < end_k)
|
372 |
+
{
|
373 |
+
if (k == nextgroup)
|
374 |
+
{
|
375 |
+
group++;
|
376 |
+
scales_idx++;
|
377 |
+
qs_f0 = scales[scales_idx][0];
|
378 |
+
qs_f1 = scales[scales_idx][1];
|
379 |
+
qs_f2 = scales[scales_idx][2];
|
380 |
+
qs_f3 = scales[scales_idx][3];
|
381 |
+
nextgroup += groupsize;
|
382 |
+
}
|
383 |
+
|
384 |
+
#pragma unroll
|
385 |
+
for (int j = 0; j < 1; j++)
|
386 |
+
{
|
387 |
+
int4 load_int4[3];
|
388 |
+
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
389 |
+
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
390 |
+
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
391 |
+
|
392 |
+
half2 dq[4][16];
|
393 |
+
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
394 |
+
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
395 |
+
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
396 |
+
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
397 |
+
|
398 |
+
for (int m = 0; m < m_count; m++)
|
399 |
+
{
|
400 |
+
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
401 |
+
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
402 |
+
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
403 |
+
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
404 |
+
}
|
405 |
+
a_ptr += 32;
|
406 |
+
}
|
407 |
+
k += 32;
|
408 |
+
}
|
409 |
+
|
410 |
+
while (k < rows_2 && k < end_k)
|
411 |
+
{
|
412 |
+
if (k == nextgroup)
|
413 |
+
{
|
414 |
+
group++;
|
415 |
+
scales_idx++;
|
416 |
+
qs_f0 = scales[scales_idx][0];
|
417 |
+
qs_f1 = scales[scales_idx][1];
|
418 |
+
qs_f2 = scales[scales_idx][2];
|
419 |
+
qs_f3 = scales[scales_idx][3];
|
420 |
+
nextgroup += groupsize;
|
421 |
+
}
|
422 |
+
|
423 |
+
#pragma unroll
|
424 |
+
for (int j = 0; j < 2; j++)
|
425 |
+
{
|
426 |
+
int4 load_int4[1];
|
427 |
+
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
428 |
+
|
429 |
+
half2 dq[4][8];
|
430 |
+
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
|
431 |
+
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
|
432 |
+
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
|
433 |
+
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
|
434 |
+
|
435 |
+
for (int m = 0; m < m_count; m++)
|
436 |
+
{
|
437 |
+
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
438 |
+
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
439 |
+
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
440 |
+
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
441 |
+
}
|
442 |
+
|
443 |
+
a_ptr += 16;
|
444 |
+
}
|
445 |
+
k += 32;
|
446 |
+
}
|
447 |
+
|
448 |
+
// Accumulate column sums in c
|
449 |
+
|
450 |
+
for (int m = 0; m < m_count; m++)
|
451 |
+
{
|
452 |
+
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
453 |
+
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
454 |
+
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
455 |
+
atomicAdd(out , result01);
|
456 |
+
atomicAdd(out + 1, result23);
|
457 |
+
}
|
458 |
+
}
|
459 |
+
|
460 |
+
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
|
461 |
+
{
|
462 |
+
#if BLOCK_M_SIZE_MAX >= 1
|
463 |
+
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
|
464 |
+
#endif
|
465 |
+
#if BLOCK_M_SIZE_MAX >= 2
|
466 |
+
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
|
467 |
+
#endif
|
468 |
+
#if BLOCK_M_SIZE_MAX >= 3
|
469 |
+
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
|
470 |
+
#endif
|
471 |
+
#if BLOCK_M_SIZE_MAX >= 4
|
472 |
+
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
|
473 |
+
#endif
|
474 |
+
#if BLOCK_M_SIZE_MAX >= 5
|
475 |
+
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
|
476 |
+
#endif
|
477 |
+
#if BLOCK_M_SIZE_MAX >= 6
|
478 |
+
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
|
479 |
+
#endif
|
480 |
+
#if BLOCK_M_SIZE_MAX >= 7
|
481 |
+
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
|
482 |
+
#endif
|
483 |
+
#if BLOCK_M_SIZE_MAX >= 8
|
484 |
+
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
|
485 |
+
#endif
|
486 |
+
return NULL;
|
487 |
+
}
|
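
Editor's note on the kernel above: blockIdx.x selects a tile of BLOCK_KN_SIZE * 4 output columns, blockIdx.y a group of m_count rows, and blockIdx.z a BLOCK_KN_SIZE slice of k, with one thread per k element; partial sums from different k slices are combined with atomicAdd, which is why only blockIdx.z == 0 clears the output. The actual launcher (gemm_half_q_half_cuda_part) lives earlier in q_gemm.cu and is not shown in this hunk, so the sketch below only illustrates the launch geometry implied by the index arithmetic; BLOCK_KN_SIZE = 128 and the rounding-up DIVIDE() helper are assumed from the other files in this diff.

// Editor's sketch (not part of the commit): launch geometry inferred from the
// kernel's index arithmetic, assuming the surrounding headers are in scope.
void launch_gemm_half_q_half(const half* a, QMatrix* b, half* c,
                             int size_m, int size_n, int size_k,
                             int m_count, bool clear)
{
    dim3 blockDim(BLOCK_KN_SIZE, 1, 1);
    dim3 gridDim(DIVIDE(size_n, BLOCK_KN_SIZE * 4),   // column tiles
                 DIVIDE(size_m, m_count),             // row tiles
                 DIVIDE(size_k, BLOCK_KN_SIZE));      // k slices, combined by atomicAdd
    fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
    kernel<<<gridDim, blockDim>>>(a, b->cuda_q_weight, b->cuda_q_scale,
                                  b->cuda_q_scale_max, c,
                                  size_m, size_n, size_k,
                                  b->groups, b->groupsize, b->cuda_q_perm,
                                  b->rows_8, b->rows_6, b->rows_5,
                                  b->rows_4, b->rows_3, b->rows_2, clear);
}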
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_gemm_kernel_gptq.cuh
ADDED
@@ -0,0 +1,219 @@
1 |
+
#include "compat.cuh"
|
2 |
+
|
3 |
+
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
|
4 |
+
{
|
5 |
+
half2 result = {};
|
6 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
7 |
+
#pragma unroll
|
8 |
+
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
9 |
+
return __hadd2(result, g_result);
|
10 |
+
}
|
11 |
+
|
12 |
+
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
13 |
+
{
|
14 |
+
half2 result = {};
|
15 |
+
const half2* a2_ptr = (const half2*)a_ptr;
|
16 |
+
#pragma unroll
|
17 |
+
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
18 |
+
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
19 |
+
}
|
20 |
+
|
21 |
+
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
22 |
+
(
|
23 |
+
const half*,
|
24 |
+
const uint32_t*,
|
25 |
+
const uint32_t*,
|
26 |
+
const half*,
|
27 |
+
half*,
|
28 |
+
const int,
|
29 |
+
const int,
|
30 |
+
const int,
|
31 |
+
const int,
|
32 |
+
const int,
|
33 |
+
const uint16_t*,
|
34 |
+
const int,
|
35 |
+
const bool
|
36 |
+
);
|
37 |
+
|
38 |
+
template <bool first_block, int m_count>
|
39 |
+
__global__ void gemm_half_q_half_gptq_kernel
|
40 |
+
(
|
41 |
+
const half* __restrict__ a,
|
42 |
+
const uint32_t* __restrict__ b_q_weight,
|
43 |
+
const uint32_t* __restrict__ b_gptq_qzeros,
|
44 |
+
const half* __restrict__ b_gptq_scales,
|
45 |
+
half* __restrict__ c,
|
46 |
+
const int size_m,
|
47 |
+
const int size_n,
|
48 |
+
const int size_k,
|
49 |
+
const int groups,
|
50 |
+
const int groupsize,
|
51 |
+
const uint16_t* __restrict__ b_q_perm,
|
52 |
+
const int rows_4,
|
53 |
+
const bool clear
|
54 |
+
)
|
55 |
+
{
|
56 |
+
MatrixView_half a_(a, size_m, size_k);
|
57 |
+
MatrixView_half_rw c_(c, size_m, size_n);
|
58 |
+
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
59 |
+
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
60 |
+
|
61 |
+
int t = threadIdx.x;
|
62 |
+
|
63 |
+
// Block
|
64 |
+
|
65 |
+
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
66 |
+
int offset_m = blockIdx.y * m_count;
|
67 |
+
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
68 |
+
|
69 |
+
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
70 |
+
int end_m = min(offset_m + m_count, size_m);
|
71 |
+
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
72 |
+
|
73 |
+
int n = offset_n + t * 4;
|
74 |
+
|
75 |
+
// Preload block_a
|
76 |
+
|
77 |
+
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
78 |
+
|
79 |
+
if (offset_k + t < end_k)
|
80 |
+
{
|
81 |
+
for (int m = 0; m < m_count; ++m)
|
82 |
+
{
|
83 |
+
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
84 |
+
half* block_a_ptr = block_a[m];
|
85 |
+
|
86 |
+
half a0;
|
87 |
+
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
88 |
+
else a0 = a_ptr[offset_k + t];
|
89 |
+
block_a_ptr[t] = a0;
|
90 |
+
}
|
91 |
+
}
|
92 |
+
|
93 |
+
// Zero output
|
94 |
+
|
95 |
+
if (n >= size_n) return;
|
96 |
+
|
97 |
+
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
98 |
+
{
|
99 |
+
for (int m = 0; m < m_count; m++)
|
100 |
+
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
101 |
+
}
|
102 |
+
|
103 |
+
__syncthreads();
|
104 |
+
|
105 |
+
// Find initial group
|
106 |
+
|
107 |
+
int group = offset_k / groupsize;
|
108 |
+
int nextgroup = offset_k + groupsize;
|
109 |
+
|
110 |
+
// a, b offset
|
111 |
+
|
112 |
+
int qk = offset_k / (32 / 4);
|
113 |
+
|
114 |
+
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
115 |
+
const half* a_ptr = &block_a[0][0];
|
116 |
+
int a_stride = BLOCK_KN_SIZE;
|
117 |
+
|
118 |
+
// Initial group
|
119 |
+
|
120 |
+
int zeros[4];
|
121 |
+
float scales[4];
|
122 |
+
half2 z1z16[4][2];
|
123 |
+
half2 y1y16[4][2];
|
124 |
+
b_gptq_qzeros_.item4(zeros, group, n);
|
125 |
+
b_gptq_scales_.item4_f(scales, group, n);
|
126 |
+
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
|
127 |
+
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
|
128 |
+
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
|
129 |
+
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
|
130 |
+
|
131 |
+
// __syncthreads();
|
132 |
+
|
133 |
+
// Column result
|
134 |
+
|
135 |
+
float block_c[m_count][4] = {};
|
136 |
+
|
137 |
+
// Dequantize and multiply
|
138 |
+
|
139 |
+
int k = offset_k;
|
140 |
+
while (k < end_k)
|
141 |
+
{
|
142 |
+
if (k == nextgroup)
|
143 |
+
{
|
144 |
+
group++;
|
145 |
+
nextgroup += groupsize;
|
146 |
+
b_gptq_qzeros_.item4(zeros, group, n);
|
147 |
+
b_gptq_scales_.item4_f(scales, group, n);
|
148 |
+
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
|
149 |
+
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
|
150 |
+
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
|
151 |
+
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
|
152 |
+
}
|
153 |
+
|
154 |
+
#pragma unroll
|
155 |
+
for (int j = 0; j < 4; j++)
|
156 |
+
{
|
157 |
+
const int4* b_ptr4 = (int4*) b_ptr;
|
158 |
+
int4 load_int4 = *b_ptr4;
|
159 |
+
|
160 |
+
half2 dq[4][4];
|
161 |
+
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
162 |
+
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
163 |
+
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
164 |
+
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
165 |
+
|
166 |
+
#pragma unroll
|
167 |
+
for (int m = 0; m < m_count; m++)
|
168 |
+
{
|
169 |
+
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
170 |
+
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
171 |
+
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
172 |
+
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
173 |
+
}
|
174 |
+
|
175 |
+
b_ptr += size_n;
|
176 |
+
a_ptr += 8;
|
177 |
+
}
|
178 |
+
|
179 |
+
k += 32;
|
180 |
+
}
|
181 |
+
|
182 |
+
for (int m = 0; m < m_count; m++)
|
183 |
+
{
|
184 |
+
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
185 |
+
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
186 |
+
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
187 |
+
atomicAdd(out , result01);
|
188 |
+
atomicAdd(out + 1, result23);
|
189 |
+
}
|
190 |
+
}
|
191 |
+
|
192 |
+
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
193 |
+
{
|
194 |
+
#if BLOCK_M_SIZE_MAX >= 1
|
195 |
+
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
196 |
+
#endif
|
197 |
+
#if BLOCK_M_SIZE_MAX >= 2
|
198 |
+
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
199 |
+
#endif
|
200 |
+
#if BLOCK_M_SIZE_MAX >= 3
|
201 |
+
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
202 |
+
#endif
|
203 |
+
#if BLOCK_M_SIZE_MAX >= 4
|
204 |
+
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
205 |
+
#endif
|
206 |
+
#if BLOCK_M_SIZE_MAX >= 5
|
207 |
+
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
208 |
+
#endif
|
209 |
+
#if BLOCK_M_SIZE_MAX >= 6
|
210 |
+
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
211 |
+
#endif
|
212 |
+
#if BLOCK_M_SIZE_MAX >= 7
|
213 |
+
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
214 |
+
#endif
|
215 |
+
#if BLOCK_M_SIZE_MAX >= 8
|
216 |
+
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
217 |
+
#endif
|
218 |
+
return NULL;
|
219 |
+
}
|
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cu
ADDED
@@ -0,0 +1,623 @@
1 |
+
#include "q_matrix.cuh"
|
2 |
+
#include "matrix_view.cuh"
|
3 |
+
#include "util.cuh"
|
4 |
+
|
5 |
+
#include "quant/qdq_2.cuh"
|
6 |
+
#include "quant/qdq_3.cuh"
|
7 |
+
#include "quant/qdq_4.cuh"
|
8 |
+
#include "quant/qdq_5.cuh"
|
9 |
+
#include "quant/qdq_6.cuh"
|
10 |
+
#include "quant/qdq_8.cuh"
|
11 |
+
|
12 |
+
#define BLOCK_KN_SIZE 128
|
13 |
+
|
14 |
+
#define THREADS_X 32
|
15 |
+
#define THREADS_Y 32
|
16 |
+
|
17 |
+
// Shuffle quantized data on load
|
18 |
+
|
19 |
+
__global__ void shuffle_kernel
|
20 |
+
(
|
21 |
+
uint32_t* __restrict__ b_q_weight,
|
22 |
+
const int size_k,
|
23 |
+
const int size_n,
|
24 |
+
const int rows_8,
|
25 |
+
const int rows_6,
|
26 |
+
const int rows_5,
|
27 |
+
const int rows_4,
|
28 |
+
const int rows_3,
|
29 |
+
const int rows_2
|
30 |
+
)
|
31 |
+
{
|
32 |
+
int n = blockIdx.x * THREADS_X + threadIdx.x;
|
33 |
+
if (n >= size_n) return;
|
34 |
+
int k = 0;
|
35 |
+
uint32_t* b_ptr = b_q_weight + n;
|
36 |
+
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
|
37 |
+
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
|
38 |
+
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
|
39 |
+
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
|
40 |
+
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
|
41 |
+
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
|
42 |
+
}
|
43 |
+
|
44 |
+
|
45 |
+
// QMatrix constructor
|
46 |
+
|
47 |
+
QMatrix::QMatrix
|
48 |
+
(
|
49 |
+
const int _device,
|
50 |
+
const int _height,
|
51 |
+
const int _width,
|
52 |
+
const int _groups,
|
53 |
+
|
54 |
+
uint32_t* _q_weight,
|
55 |
+
uint16_t* _q_perm,
|
56 |
+
uint16_t* _q_invperm,
|
57 |
+
uint32_t* _q_scale,
|
58 |
+
half* _q_scale_max,
|
59 |
+
uint16_t* _q_groups,
|
60 |
+
|
61 |
+
uint32_t* _gptq_qzeros,
|
62 |
+
half* _gptq_scales,
|
63 |
+
uint32_t* _gptq_g_idx,
|
64 |
+
|
65 |
+
half* _temp_dq
|
66 |
+
) :
|
67 |
+
device(_device),
|
68 |
+
height(_height),
|
69 |
+
width(_width),
|
70 |
+
groups(_groups),
|
71 |
+
temp_dq(_temp_dq)
|
72 |
+
{
|
73 |
+
cudaSetDevice(device);
|
74 |
+
|
75 |
+
failed = false;
|
76 |
+
|
77 |
+
cuda_q_weight = _q_weight;
|
78 |
+
cuda_q_perm = _q_perm;
|
79 |
+
cuda_q_invperm = _q_invperm;
|
80 |
+
cuda_q_scale = _q_scale;
|
81 |
+
cuda_q_scale_max = _q_scale_max;
|
82 |
+
cuda_q_groups = _q_groups;
|
83 |
+
cuda_gptq_qzeros = _gptq_qzeros;
|
84 |
+
cuda_gptq_scales = _gptq_scales;
|
85 |
+
|
86 |
+
is_gptq = (_gptq_qzeros != NULL);
|
87 |
+
|
88 |
+
groupsize = 1;
|
89 |
+
while (groupsize * groups < height) groupsize *= 2;
|
90 |
+
|
91 |
+
// Create group map
|
92 |
+
|
93 |
+
rows_8 = 0;
|
94 |
+
rows_6 = 0;
|
95 |
+
rows_5 = 0;
|
96 |
+
rows_4 = 0;
|
97 |
+
rows_3 = 0;
|
98 |
+
rows_2 = 0;
|
99 |
+
|
100 |
+
if (!is_gptq)
|
101 |
+
{
|
102 |
+
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
103 |
+
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
104 |
+
|
105 |
+
for (int i = 0; i < groups; i++)
|
106 |
+
{
|
107 |
+
int bits = cpu_q_groups[i * 2];
|
108 |
+
if (bits == 8) rows_8 += groupsize;
|
109 |
+
if (bits == 6) rows_6 += groupsize;
|
110 |
+
if (bits == 5) rows_5 += groupsize;
|
111 |
+
if (bits == 4) rows_4 += groupsize;
|
112 |
+
if (bits == 3) rows_3 += groupsize;
|
113 |
+
if (bits == 2) rows_2 += groupsize;
|
114 |
+
}
|
115 |
+
|
116 |
+
free(cpu_q_groups);
|
117 |
+
|
118 |
+
rows_6 += rows_8;
|
119 |
+
rows_5 += rows_6;
|
120 |
+
rows_4 += rows_5;
|
121 |
+
rows_3 += rows_4;
|
122 |
+
rows_2 += rows_3;
|
123 |
+
}
|
124 |
+
else
|
125 |
+
{
|
126 |
+
rows_4 = height;
|
127 |
+
rows_3 = height;
|
128 |
+
rows_2 = height;
|
129 |
+
|
130 |
+
if (_gptq_g_idx)
|
131 |
+
{
|
132 |
+
if (!make_sequential(_gptq_g_idx))
|
133 |
+
{
|
134 |
+
failed = true;
|
135 |
+
//printf("FAIL\n");
|
136 |
+
return;
|
137 |
+
}
|
138 |
+
}
|
139 |
+
}
|
140 |
+
|
141 |
+
// Shuffle quantized data
|
142 |
+
|
143 |
+
dim3 blockDim, gridDim;
|
144 |
+
blockDim.x = THREADS_X;
|
145 |
+
blockDim.y = 1;
|
146 |
+
gridDim.x = DIVIDE(width, THREADS_X);
|
147 |
+
gridDim.y = 1;
|
148 |
+
|
149 |
+
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
150 |
+
}
|
151 |
+
|
152 |
+
QMatrix::~QMatrix()
|
153 |
+
{
|
154 |
+
}
|
155 |
+
|
156 |
+
// Reconstruct b[k,n] (GPTQ)
|
157 |
+
|
158 |
+
__global__ void reconstruct_gptq_kernel
|
159 |
+
(
|
160 |
+
const uint32_t* __restrict__ b_q_weight,
|
161 |
+
const uint16_t* __restrict__ b_q_perm,
|
162 |
+
const uint32_t* __restrict__ b_gptq_qzeros,
|
163 |
+
const half* __restrict__ b_gptq_scales,
|
164 |
+
//const uint16_t* __restrict__ b_q_groups,
|
165 |
+
const int size_k,
|
166 |
+
const int size_n,
|
167 |
+
const int groupsize,
|
168 |
+
const int groups,
|
169 |
+
half* __restrict__ b,
|
170 |
+
const int rows_4
|
171 |
+
)
|
172 |
+
{
|
173 |
+
MatrixView_half_rw b_(b, size_k, size_n);
|
174 |
+
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
175 |
+
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
176 |
+
|
177 |
+
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
178 |
+
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
179 |
+
|
180 |
+
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
181 |
+
|
182 |
+
// Preload remapping table
|
183 |
+
|
184 |
+
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
185 |
+
int t = threadIdx.x;
|
186 |
+
|
187 |
+
if (b_q_perm)
|
188 |
+
{
|
189 |
+
if (offset_k + t < size_k)
|
190 |
+
perm[t] = b_q_perm[offset_k + t];
|
191 |
+
}
|
192 |
+
|
193 |
+
// Column
|
194 |
+
|
195 |
+
int n = offset_n + t * 4;
|
196 |
+
if (n >= size_n) return;
|
197 |
+
|
198 |
+
// Find initial group
|
199 |
+
|
200 |
+
int group = offset_k / groupsize;
|
201 |
+
int nextgroup = offset_k + groupsize;
|
202 |
+
|
203 |
+
// b offset
|
204 |
+
|
205 |
+
int qk = offset_k / (32 / 4);
|
206 |
+
|
207 |
+
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
208 |
+
|
209 |
+
// Initial zeros/scale
|
210 |
+
|
211 |
+
int zeros[4];
|
212 |
+
half2 scales[4];
|
213 |
+
half2 z1z16[4][2];
|
214 |
+
half2 y1y16[4][2];
|
215 |
+
b_gptq_qzeros_.item4(zeros, group, n);
|
216 |
+
b_gptq_scales_.item4_h2(scales, group, n);
|
217 |
+
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
|
218 |
+
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
|
219 |
+
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
|
220 |
+
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
|
221 |
+
|
222 |
+
__syncthreads();
|
223 |
+
|
224 |
+
int k = offset_k;
|
225 |
+
int lk = 0;
|
226 |
+
|
227 |
+
while (k < end_k)
|
228 |
+
{
|
229 |
+
if (k == nextgroup)
|
230 |
+
{
|
231 |
+
group++;
|
232 |
+
nextgroup += groupsize;
|
233 |
+
b_gptq_qzeros_.item4(zeros, group, n);
|
234 |
+
b_gptq_scales_.item4_h2(scales, group, n);
|
235 |
+
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
|
236 |
+
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
|
237 |
+
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
|
238 |
+
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
|
239 |
+
}
|
240 |
+
|
241 |
+
for (int p = 0; p < 4; p++)
|
242 |
+
{
|
243 |
+
half2 dq[4][4];
|
244 |
+
const int4* b_ptr4 = (int4*) b_ptr;
|
245 |
+
int4 load_int4 = *b_ptr4;
|
246 |
+
|
247 |
+
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
248 |
+
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
249 |
+
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
250 |
+
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
251 |
+
|
252 |
+
b_ptr += size_n;
|
253 |
+
//half* dqh = (half*)dq;
|
254 |
+
if (b_q_perm)
|
255 |
+
{
|
256 |
+
for (int j = 0; j < 4; j++)
|
257 |
+
{
|
258 |
+
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
259 |
+
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
260 |
+
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
261 |
+
}
|
262 |
+
}
|
263 |
+
else
|
264 |
+
{
|
265 |
+
for (int j = 0; j < 4; j++)
|
266 |
+
{
|
267 |
+
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
268 |
+
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
269 |
+
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
270 |
+
}
|
271 |
+
}
|
272 |
+
}
|
273 |
+
k += 32;
|
274 |
+
}
|
275 |
+
}
|
276 |
+
|
277 |
+
|
278 |
+
// Reconstruct b[k,n]
|
279 |
+
|
280 |
+
__global__ void reconstruct_kernel
|
281 |
+
(
|
282 |
+
const uint32_t* __restrict__ b_q_weight,
|
283 |
+
const uint16_t* __restrict__ b_q_perm,
|
284 |
+
const uint32_t* __restrict__ b_q_scale,
|
285 |
+
const half* __restrict__ b_q_scale_max,
|
286 |
+
//const uint16_t* __restrict__ b_q_groups,
|
287 |
+
const int size_k,
|
288 |
+
const int size_n,
|
289 |
+
const int groupsize,
|
290 |
+
const int groups,
|
291 |
+
half* __restrict__ b,
|
292 |
+
const int rows_8,
|
293 |
+
const int rows_6,
|
294 |
+
const int rows_5,
|
295 |
+
const int rows_4,
|
296 |
+
const int rows_3,
|
297 |
+
const int rows_2
|
298 |
+
)
|
299 |
+
{
|
300 |
+
MatrixView_half_rw b_(b, size_k, size_n);
|
301 |
+
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
302 |
+
|
303 |
+
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
304 |
+
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
|
305 |
+
|
306 |
+
// Preload remapping table
|
307 |
+
|
308 |
+
int t = threadIdx.x;
|
309 |
+
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
310 |
+
if (offset_k + t < size_k)
|
311 |
+
perm[t] = b_q_perm[offset_k + t];
|
312 |
+
|
313 |
+
// Column
|
314 |
+
|
315 |
+
int n = offset_n + t;
|
316 |
+
if (n >= size_n) return;
|
317 |
+
|
318 |
+
// Find initial group
|
319 |
+
|
320 |
+
int group = offset_k / groupsize;
|
321 |
+
|
322 |
+
int pre_rows_8 = min(rows_8, offset_k);
|
323 |
+
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
324 |
+
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
325 |
+
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
326 |
+
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
327 |
+
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
328 |
+
int qk = 0;
|
329 |
+
qk += pre_rows_8 / 32 * 8;
|
330 |
+
qk += pre_rows_6 / 32 * 6;
|
331 |
+
qk += pre_rows_5 / 32 * 5;
|
332 |
+
qk += pre_rows_4 / 32 * 4;
|
333 |
+
qk += pre_rows_3 / 32 * 3;
|
334 |
+
qk += pre_rows_2 / 32 * 2;
|
335 |
+
|
336 |
+
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
337 |
+
|
338 |
+
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
339 |
+
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
340 |
+
int nextgroup = offset_k + groupsize;
|
341 |
+
|
342 |
+
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
343 |
+
int k = offset_k;
|
344 |
+
int lk = 0;
|
345 |
+
|
346 |
+
__syncthreads();
|
347 |
+
|
348 |
+
while (k < rows_8 && k < end_k)
|
349 |
+
{
|
350 |
+
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
351 |
+
for (int p = 0; p < 4; p++)
|
352 |
+
{
|
353 |
+
half2 dq[4];
|
354 |
+
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
355 |
+
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
356 |
+
dequant_8bit_8(q_0, q_1, dq, size_n);
|
357 |
+
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
358 |
+
half* dqh = (half*) dq;
|
359 |
+
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
360 |
+
}
|
361 |
+
k += 32;
|
362 |
+
}
|
363 |
+
|
364 |
+
while (k < rows_6 && k < end_k)
|
365 |
+
{
|
366 |
+
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
367 |
+
for (int p = 0; p < 2; p++)
|
368 |
+
{
|
369 |
+
half2 dq[8];
|
370 |
+
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
371 |
+
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
372 |
+
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
373 |
+
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
|
374 |
+
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
375 |
+
half* dqh = (half*) dq;
|
376 |
+
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
377 |
+
}
|
378 |
+
k += 32;
|
379 |
+
}
|
380 |
+
|
381 |
+
while (k < rows_5 && k < end_k)
|
382 |
+
{
|
383 |
+
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
384 |
+
for (int p = 0; p < 1; p++)
|
385 |
+
{
|
386 |
+
half2 dq[16];
|
387 |
+
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
388 |
+
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
389 |
+
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
390 |
+
uint32_t q_3 = *b_ptr; b_ptr += size_n;
|
391 |
+
uint32_t q_4 = *b_ptr; b_ptr += size_n;
|
392 |
+
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
|
393 |
+
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
394 |
+
half* dqh = (half*) dq;
|
395 |
+
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
396 |
+
}
|
397 |
+
k += 32;
|
398 |
+
}
|
399 |
+
|
400 |
+
while (k < rows_4 && k < end_k)
|
401 |
+
{
|
402 |
+
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
403 |
+
for (int p = 0; p < 4; p++)
|
404 |
+
{
|
405 |
+
half2 dq[4];
|
406 |
+
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
407 |
+
dequant_4bit_8(q_0, dq, size_n);
|
408 |
+
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
409 |
+
half* dqh = (half*) dq;
|
410 |
+
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
411 |
+
}
|
412 |
+
k += 32;
|
413 |
+
}
|
414 |
+
|
415 |
+
while (k < rows_3 && k < end_k)
|
416 |
+
{
|
417 |
+
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
418 |
+
for (int p = 0; p < 1; p++)
|
419 |
+
{
|
420 |
+
half2 dq[16];
|
421 |
+
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
422 |
+
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
423 |
+
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
424 |
+
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
|
425 |
+
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
426 |
+
half* dqh = (half*) dq;
|
427 |
+
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
428 |
+
}
|
429 |
+
k += 32;
|
430 |
+
}
|
431 |
+
|
432 |
+
while (k < rows_2 && k < end_k)
|
433 |
+
{
|
434 |
+
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
435 |
+
for (int p = 0; p < 2; p++)
|
436 |
+
{
|
437 |
+
half2 dq[8];
|
438 |
+
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
439 |
+
dequant_2bit_16(q_0, dq, size_n);
|
440 |
+
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
441 |
+
half* dqh = (half*) dq;
|
442 |
+
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
443 |
+
}
|
444 |
+
k += 32;
|
445 |
+
}
|
446 |
+
}
|
447 |
+
|
448 |
+
void QMatrix::reconstruct(half* out)
|
449 |
+
{
|
450 |
+
dim3 blockDim, gridDim;
|
451 |
+
blockDim.x = BLOCK_KN_SIZE;
|
452 |
+
blockDim.y = 1;
|
453 |
+
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
454 |
+
|
455 |
+
if (!is_gptq)
|
456 |
+
{
|
457 |
+
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
458 |
+
reconstruct_kernel<<<gridDim, blockDim>>>
|
459 |
+
(
|
460 |
+
cuda_q_weight,
|
461 |
+
cuda_q_perm,
|
462 |
+
cuda_q_scale,
|
463 |
+
cuda_q_scale_max,
|
464 |
+
//cuda_q_groups,
|
465 |
+
height,
|
466 |
+
width,
|
467 |
+
groupsize,
|
468 |
+
groups,
|
469 |
+
out,
|
470 |
+
rows_8,
|
471 |
+
rows_6,
|
472 |
+
rows_5,
|
473 |
+
rows_4,
|
474 |
+
rows_3,
|
475 |
+
rows_2
|
476 |
+
);
|
477 |
+
}
|
478 |
+
else
|
479 |
+
{
|
480 |
+
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
|
481 |
+
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
|
482 |
+
(
|
483 |
+
cuda_q_weight,
|
484 |
+
cuda_q_perm,
|
485 |
+
cuda_gptq_qzeros,
|
486 |
+
cuda_gptq_scales,
|
487 |
+
//const uint16_t* __restrict__ b_q_groups,
|
488 |
+
height,
|
489 |
+
width,
|
490 |
+
groupsize,
|
491 |
+
groups,
|
492 |
+
out,
|
493 |
+
rows_4
|
494 |
+
);
|
495 |
+
}
|
496 |
+
}
|
497 |
+
|
498 |
+
__global__ void make_sequential_kernel
|
499 |
+
(
|
500 |
+
const uint32_t* __restrict__ w,
|
501 |
+
uint32_t* __restrict__ w_new,
|
502 |
+
const uint16_t* __restrict__ q_perm,
|
503 |
+
const int w_height,
|
504 |
+
const int w_width
|
505 |
+
)
|
506 |
+
{
|
507 |
+
const uint64_t* w2 = (uint64_t*) w;
|
508 |
+
uint64_t* w_new2 = (uint64_t*) w_new;
|
509 |
+
int w2_stride = w_width >> 1;
|
510 |
+
|
511 |
+
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
|
512 |
+
if (w2_column >= w2_stride) return;
|
513 |
+
|
514 |
+
int w_new2_row = blockIdx.y;
|
515 |
+
|
516 |
+
int q_perm_idx = w_new2_row << 3;
|
517 |
+
|
518 |
+
uint64_t dst = 0;
|
519 |
+
|
520 |
+
#pragma unroll
|
521 |
+
for (int i = 0; i < 8; i++)
|
522 |
+
{
|
523 |
+
int source_row = q_perm[q_perm_idx++];
|
524 |
+
|
525 |
+
int w2_row = source_row >> 3;
|
526 |
+
int w2_subrow = source_row & 0x07;
|
527 |
+
int w2_row_shift = w2_subrow << 2;
|
528 |
+
int wnew2_row_shift = i << 2;
|
529 |
+
|
530 |
+
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
531 |
+
src >>= w2_row_shift;
|
532 |
+
src &= 0x0000000f0000000f;
|
533 |
+
src <<= wnew2_row_shift;
|
534 |
+
dst |= src;
|
535 |
+
}
|
536 |
+
|
537 |
+
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
538 |
+
}
|
539 |
+
|
540 |
+
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
541 |
+
{
|
542 |
+
uint32_t* cuda_new_qweight = NULL;
|
543 |
+
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
544 |
+
if (err != cudaSuccess) {
|
545 |
+
cudaError_t cuda_status = cudaGetLastError(); // Clear error
|
546 |
+
return false;
|
547 |
+
}
|
548 |
+
|
549 |
+
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
550 |
+
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
551 |
+
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
552 |
+
|
553 |
+
// Group histogram
|
554 |
+
|
555 |
+
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
556 |
+
|
557 |
+
// Group map
|
558 |
+
|
559 |
+
for (int i = 0, acc = 0; i < groups; i++)
|
560 |
+
{
|
561 |
+
short tmp = cpu_g_idx_map[i];
|
562 |
+
cpu_g_idx_map[i] = acc;
|
563 |
+
acc += tmp;
|
564 |
+
}
|
565 |
+
|
566 |
+
// X map (inverse)
|
567 |
+
|
568 |
+
for (int row = 0; row < height; row++)
|
569 |
+
{
|
570 |
+
uint32_t target_group = cpu_g_idx[row];
|
571 |
+
uint32_t target_row = cpu_g_idx_map[target_group];
|
572 |
+
cpu_g_idx_map[target_group]++;
|
573 |
+
cpu_x_map_inv[row] = target_row;
|
574 |
+
}
|
575 |
+
|
576 |
+
// X map
|
577 |
+
|
578 |
+
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
579 |
+
|
580 |
+
// Reduce to uint16_t
|
581 |
+
|
582 |
+
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
|
583 |
+
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
|
584 |
+
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
|
585 |
+
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
|
586 |
+
|
587 |
+
// Move to CUDA
|
588 |
+
|
589 |
+
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
590 |
+
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
591 |
+
|
592 |
+
// Rearrange rows in w
|
593 |
+
|
594 |
+
dim3 blockDim, gridDim;
|
595 |
+
blockDim.x = THREADS_X;
|
596 |
+
blockDim.y = 1;
|
597 |
+
gridDim.x = DIVIDE(width, THREADS_X);
|
598 |
+
gridDim.y = height / 8;
|
599 |
+
|
600 |
+
make_sequential_kernel<<<gridDim, blockDim>>>
|
601 |
+
(
|
602 |
+
cuda_q_weight,
|
603 |
+
cuda_new_qweight,
|
604 |
+
cuda_q_perm,
|
605 |
+
height / 8,
|
606 |
+
width
|
607 |
+
);
|
608 |
+
|
609 |
+
// Replace qweights
|
610 |
+
|
611 |
+
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
612 |
+
|
613 |
+
// Cleanup
|
614 |
+
|
615 |
+
cudaDeviceSynchronize();
|
616 |
+
|
617 |
+
cudaFree(cuda_new_qweight);
|
618 |
+
free(cpu_g_idx_map);
|
619 |
+
free(cpu_x_map);
|
620 |
+
free(cpu_x_map_inv);
|
621 |
+
|
622 |
+
return true;
|
623 |
+
}
|
AutoAWQ_kernels/awq_ext/exllamav2/cuda/q_matrix.cuh
ADDED
@@ -0,0 +1,73 @@
+#ifndef _q_matrix_cuh
+#define _q_matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+#define MAX_SUPERGROUPS 16
+
+class QMatrix
+{
+public:
+
+    int device;
+    bool is_gptq;
+
+    int height;
+    int width;
+    int groups;
+    int groupsize;
+
+    int rows_8;
+    int rows_6;
+    int rows_5;
+    int rows_4;
+    int rows_3;
+    int rows_2;
+
+    uint32_t* cuda_q_weight = NULL;
+    uint16_t* cuda_q_perm = NULL;
+    uint16_t* cuda_q_invperm = NULL;
+    uint32_t* cuda_q_scale = NULL;
+    half* cuda_q_scale_max = NULL;
+    uint16_t* cuda_q_groups = NULL;
+    uint32_t* cuda_gptq_qzeros = NULL;
+    half* cuda_gptq_scales = NULL;
+
+    half* temp_dq;
+
+    bool failed;
+
+    QMatrix
+    (
+        const int _device,
+        const int _height,
+        const int _width,
+        const int _groups,
+
+        uint32_t* _q_weight,
+        uint16_t* _q_perm,
+        uint16_t* _q_invperm,
+        uint32_t* _q_scale,
+        half* _q_scale_max,
+        uint16_t* _q_groups,
+
+        uint32_t* _gptq_qzeros,
+        half* _gptq_scales,
+        uint32_t* _gptq_g_idx,
+
+        half* _temp_dq
+    );
+
+    ~QMatrix();
+
+    void reconstruct(half* out);
+    bool make_sequential(const uint32_t* cpu_g_idx);
+
+private:
+
+};
+
+#endif
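
A hedged usage sketch for the class declared above (editor's addition, not part of the commit). It wraps already-uploaded GPTQ tensors in a QMatrix and reconstructs the FP16 weight; every pointer is a device pointer the caller owns, and passing NULL for the EXL2-only scale and group tables selects the is_gptq path in the constructor.

// Editor's sketch (not part of the commit); names like d_qweight are illustrative.
QMatrix* wrap_gptq(int device, int k, int n, int groups,
                   uint32_t* d_qweight, uint16_t* d_q_perm, uint16_t* d_q_invperm,
                   uint32_t* d_qzeros, half* d_scales, uint32_t* d_g_idx,
                   half* d_temp_dq)
{
    QMatrix* m = new QMatrix(device, k, n, groups,
                             d_qweight, d_q_perm, d_q_invperm,
                             NULL, NULL, NULL,           // EXL2 scale/group tables unused
                             d_qzeros, d_scales, d_g_idx,
                             d_temp_dq);
    if (m->failed) { delete m; return NULL; }             // make_sequential could not allocate
    m->reconstruct(d_temp_dq);                             // dequantize into k*n halves
    return m;
}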
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_2.cuh
ADDED
@@ -0,0 +1,103 @@
#ifndef _qdq_2_cuh
#define _qdq_2_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_2BIT == 1

// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        uint32_t qa0 = qa & 0x03;
        uint32_t qa1 = (qa & 0x0c) >> 2;
        qa >>= 4;
        qb |= (qa1 << (i * 2 + 16));
        qb |= (qa0 << (i * 2));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y4_ = __float2half_rn(1.0f / 4.0f);
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y4 = __halves2half2(y4_, y4_);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_ = __float2half_rn(-1024.0f - 2.0f);
    const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
    const half2 z1 = __halves2half2(z1_, z1_);
    const half2 z4 = __halves2half2(z4_, z4_);
    const half2 z16 = __halves2half2(z16_, z16_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
    qa >>= 8;
    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9])      + 1024
    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y4,  z4);
    dq[2] = __hfma2(q2.as_half2, y16, z16);
    dq[3] = __hfma2(q3.as_half2, y64, z64);
    dq[4] = __hadd2(q4.as_half2, z1);
    dq[5] = __hfma2(q5.as_half2, y4,  z4);
    dq[6] = __hfma2(q6.as_half2, y16, z16);
    dq[7] = __hfma2(q7.as_half2, y64, z64);
}

#else

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
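Note on the constant 0x64006400 used above: 0x6400 is the bit pattern of half(1024.0), and at that exponent one mantissa ulp equals exactly 1.0, so OR-ing a small quantized field into the low mantissa bits produces the half value 1024 + q (packed twice per 32-bit word for the two half2 lanes). The __hadd2/__hfma2 lines then strip the 1024 offset, the field-position factor (4/16/64) and the zero-point in a single step. A minimal standalone check of that identity, as a hypothetical test kernel that is not part of this diff:

// bias_trick_check.cu (hypothetical, for illustration only)
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void check_bias_trick()
{
    for (unsigned int q = 0; q < 4; q++)
    {
        // 0x6400 == bit pattern of half(1024.0); OR the 2-bit value into the mantissa.
        unsigned short bits = (unsigned short)(0x6400 | q);
        half h = __ushort_as_half(bits);
        printf("q=%u -> %f (expected %f)\n", q, __half2float(h), 1024.0f + (float)q);
    }
}

int main()
{
    check_bias_trick<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}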
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_3.cuh
ADDED
@@ -0,0 +1,169 @@
#ifndef _qdq_3_cuh
#define _qdq_3_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_3BIT == 1

// Permutation:
//
// v9997775 55333111 u8886664 44222000  (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];

    // qa: aa999888 77766655 54443332 22111000
    // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
    // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll

    uint32_t qd = qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa: ..999888 77766655 54443332 22111000
    // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
    // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
    // qd: vvvuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;

    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }

    // za: 9997775 55333111 8886664 44222000
    // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
    // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
    // qd: vvvuuu

    za |= ((qd & 0x01) >> 0) << 15;
    zb |= ((qd & 0x02) >> 1) << 15;
    zc |= ((qd & 0x04) >> 2) << 15;
    za |= ((qd & 0x08) >> 3) << 31;
    zb |= ((qd & 0x10) >> 4) << 31;
    zc |= ((qd & 0x20) >> 5) << 31;

    // za: v9997775 55333111 u8886664 44222000  (u, v lsb)
    // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
    // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y8_ = __float2half_rn(1.0f / 8.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y8 = __halves2half2(y8_, y8_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_ = __float2half_rn(-1024.0f - 4.0f);
    const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
    const half2 z1 = __halves2half2(z1_, z1_);
    const half2 z8 = __halves2half2(z8_, z8_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;

    half2_uint32 q0 ((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024
    qa >>= 6;
    half2_uint32 q2 ((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3 ((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024
    half2_uint32 q4 ((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
    qa >>= 9;
    qa &= 0x00010001;
    half2_uint32 q5 ((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024
    half2_uint32 q6 ((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024
    qb >>= 6;
    half2_uint32 q7 ((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024
    half2_uint32 q8 ((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024
    half2_uint32 q9 ((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
    qb >>= 8;
    qb &= 0x00020002;
    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024
    half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024
    qc >>= 6;
    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024
    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
    qc >>= 7;
    qc &= 0x00040004;
    half2_uint32 q15((qa | qb | qc) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);
    dq[ 4] = __hfma2( q4.as_half2, y64, z64);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);
    dq[ 7] = __hadd2( q7.as_half2, z1);
    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);
    dq[ 9] = __hfma2( q9.as_half2, y64, z64);
    dq[10] = __hadd2(q10.as_half2, z1);
    dq[11] = __hfma2(q11.as_half2, y8,  z8);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y8,  z8);
    dq[14] = __hfma2(q14.as_half2, y64, z64);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 10; i++) dqh[     i] = dq_ns(exb(q_0, i * 3,     0x07), 4);
    dqh[10] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb(q_1, i * 3 + 1, 0x07), 4);
    dqh[21] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb(q_2, i * 3 + 2, 0x07), 4);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_4.cuh
ADDED
@@ -0,0 +1,227 @@
#ifndef _qdq_4_cuh
#define _qdq_4_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_4BIT == 1

// Permutation:
//
// 77775555 33331111 66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;
        uint32_t qa1 = (qa & 0xf0) >> 4;
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_ = __float2half_rn(-1024.0f - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1 = __halves2half2(z1_, z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1z16)[2],
    half2(&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}


__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2, z1z16[0]);           // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2, z1z16[0]);           // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
    }
}

#else

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1)[2],
    half2(&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}

#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_5.cuh
ADDED
@@ -0,0 +1,207 @@
#ifndef _qdq_5_cuh
#define _qdq_5_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_5BIT == 1

// Permutation:
//
// v5555533 33311111 u4444422 22200000  (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];
    uint32_t qd = q[3 * stride];
    uint32_t qe = q[4 * stride];

    // qa: 66555554 44443333 32222211 11100000
    // qb: ccccbbbb baaaaa99 99988888 77777666
    // qc: jiiiiihh hhhggggg fffffeee eedddddc
    // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
    // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp

    uint32_t qf = qe >> 22;
    qe <<= 8;
    qe |= qd >> 24;
    qd <<= 6;
    qd |= qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa: 555554 44443333 32222211 11100000
    // qb: bbbbba aaaa9999 98888877 77766666
    // qc: hhhhhg ggggffff feeeeedd dddccccc
    // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
    // qe: ttttts ssssrrrr rqqqqqpp pppooooo
    // qf: vv vvvuuuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;
    uint32_t zd = 0;
    uint32_t ze = 0;

    for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }

    // za: 5555533 33311111 4444422 22200000
    // zb: bbbbb99 99977777 aaaaa88 88866666
    // zc: hhhhhff fffddddd gggggee eeeccccc
    // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
    // ze: tttttrr rrrppppp sssssqq qqqooooo
    // qf: vv vvvuuuuu

    za |= ((qf & 0x001) >> 0) << 15;
    zb |= ((qf & 0x002) >> 1) << 15;
    zc |= ((qf & 0x004) >> 2) << 15;
    zd |= ((qf & 0x008) >> 3) << 15;
    ze |= ((qf & 0x010) >> 4) << 15;
    za |= ((qf & 0x020) >> 5) << 31;
    zb |= ((qf & 0x040) >> 6) << 31;
    zc |= ((qf & 0x080) >> 7) << 31;
    zd |= ((qf & 0x100) >> 8) << 31;
    ze |= ((qf & 0x200) >> 9) << 31;

    // za: v5555533 33311111 u4444422 22200000  (u, v lsb)
    // zb: vbbbbb99 99977777 uaaaaa88 88866666
    // zc: vhhhhhff fffddddd ugggggee eeeccccc
    // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
    // ze: vtttttrr rrrppppp usssssqq qqqooooo

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
    q[3 * stride] = zd;
    q[4 * stride] = ze;
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y32_ = __float2half_rn(1.0f / 32.0f);
    const half2 y32 = __halves2half2(y32_, y32_);
    const half z1_ = __float2half_rn(-1024.0f - 16.0f);
    const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
    const half2 z1 = __halves2half2(z1_, z1_);
    const half2 z32 = __halves2half2(z32_, z32_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;
    uint32_t qd = q_3;
    uint32_t qe = q_4;

    half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
    qa >>= 10;
    half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5])      + 1024
    qa >>= 5;
    qa &= 0x00010001;
    half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7])      + 1024
    half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
    qb >>= 10;
    half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11])      + 1024
    qb >>= 4;
    qb &= 0x00020002;
    half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13])      + 1024
    half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
    qc >>= 10;
    half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17])      + 1024
    qc >>= 3;
    qc &= 0x00040004;
    half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19])      + 1024
    half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
    qd >>= 10;
    half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23])      + 1024
    qd >>= 2;
    qd &= 0x00080008;
    half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
    qe >>= 10;
    half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29])      + 1024
    qe >>= 1;
    qe &= 0x00100010;
    half2_uint32 q15((qa | qb | qc | qd | qe) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y32, z32);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hadd2( q3.as_half2, z1);
    dq[ 4] = __hfma2( q4.as_half2, y32, z32);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hadd2( q6.as_half2, z1);
    dq[ 7] = __hfma2( q7.as_half2, y32, z32);
    dq[ 8] = __hadd2( q8.as_half2, z1);
    dq[ 9] = __hadd2( q9.as_half2, z1);
    dq[10] = __hfma2(q10.as_half2, y32, z32);
    dq[11] = __hadd2(q11.as_half2, z1);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y32, z32);
    dq[14] = __hadd2(q14.as_half2, z1);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 6; i++) dqh[     i] = dq_ns(exb(q_0, i * 5,     0x1f), 16);
    dqh[ 6] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
    for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb(q_1, i * 5 + 3, 0x1f), 16);
    dqh[12] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
    for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb(q_2, i * 5 + 1, 0x1f), 16);
    dqh[19] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
    for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb(q_3, i * 5 + 4, 0x1f), 16);
    dqh[25] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
    for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb(q_4, i * 5 + 2, 0x1f), 16);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_6.cuh
ADDED
@@ -0,0 +1,44 @@
#ifndef _qdq_6_cuh
#define _qdq_6_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_6BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_6bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_6bit_16
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 5; i++) dqh[     i] = dq_ns(exb(q_0, i * 6,     0x3f), 32);
    dqh[ 5] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
    for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb(q_1, i * 6 + 4, 0x3f), 32);
    dqh[10] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
    for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb(q_2, i * 6 + 2, 0x3f), 32);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_8.cuh
ADDED
@@ -0,0 +1,38 @@
#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_8BIT == 1

// Not implemented

#else

__forceinline__ __device__ void shuffle_8bit_4
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_8bit_8
(
    const uint32_t q_0,
    const uint32_t q_1,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif
AutoAWQ_kernels/awq_ext/exllamav2/cuda/quant/qdq_util.cuh
ADDED
@@ -0,0 +1,51 @@
#ifndef _qdq_util_cuh
#define _qdq_util_cuh

union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

#endif
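The two exb overloads above are the generic unpacking path used by the QMODE_* == 0 fallbacks: the single-word form extracts a bit field at a fixed offset, and the two-word form uses __funnelshift_rc to extract a field that straddles a 32-bit word boundary; dq_ns then recentres the integer around its zero-point. A small hypothetical demo kernel (the values and the include path are assumptions, not taken from this diff):

// qdq_util_demo.cu (hypothetical; assumes qdq_util.cuh is on the include path)
#include <cstdio>
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "qdq_util.cuh"

__global__ void unpack_demo()
{
    uint32_t lo = 0xC0000000u; // bits 30..31 hold the low two bits of a 3-bit field
    uint32_t hi = 0x00000001u; // bit 0 holds the top bit, so the full field is 0b111 = 7

    int in_word  = exb(lo, 3, 0x07);      // field fully contained in 'lo'
    int straddle = exb(hi, lo, 30, 0x07); // field crossing the word boundary
    printf("in_word=%d straddle=%d dq_ns(straddle, 4)=%f\n",
           in_word, straddle, __half2float(dq_ns(straddle, 4)));
}

int main()
{
    unpack_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}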
AutoAWQ_kernels/awq_ext/exllamav2/cuda/util.cuh
ADDED
@@ -0,0 +1,42 @@
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))

#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))

__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
    half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
    qs_h = __hmul(qs_h, qs_h);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ float clamp(float x, float a, float b)
{
    return fmaxf(a, fminf(b, x));
}

#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort) exit(code);
    }
}
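The cuda_check/gpu_assert pair above is the host-side error-handling convention for these kernels: any CUDA runtime call can be wrapped so that a failure reports the error string plus file and line before exiting. A hypothetical host-side usage sketch (the file name and include path are assumptions):

// util_usage.cu (hypothetical, for illustration only)
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "util.cuh"

int main()
{
    half* temp = nullptr;
    // Each call is checked; on error gpu_assert prints the CUDA error string,
    // the file name and the line number, then exits.
    cuda_check(cudaMalloc((void**)&temp, 1024 * sizeof(half)));
    cuda_check(cudaMemset(temp, 0, 1024 * sizeof(half)));
    cuda_check(cudaFree(temp));
    return 0;
}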
AutoAWQ_kernels/awq_ext/exllamav2/ext.cpp
ADDED
@@ -0,0 +1,134 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#include "config.h"

#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"

#include "cpp/util.h"

// Some decluttering macros

#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")


// Quant matrix

uintptr_t make_q_matrix
(
    torch::Tensor q_weight,
    torch::Tensor q_perm,
    torch::Tensor q_invperm,
    torch::Tensor q_scale,
    torch::Tensor q_scale_max,
    torch::Tensor q_groups,
    torch::Tensor gptq_qzeros,
    torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,
    torch::Tensor temp_dq
)
{
    TORCH_CHECK_DTYPE(q_weight, kInt);
    TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
    TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
    TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
    TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
    TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
    TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
    TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);

    TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);

    int device = q_weight.device().index();
    int width = q_weight.size(1);
    int groups;
    int height;

    if (!q_scale.device().is_meta())
    {
        TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
        TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
        groups = q_scale.size(0);
        height = q_invperm.size(0);
    }
    else
    {
        TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
        TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
        groups = gptq_qzeros.size(0);
        height = q_weight.size(0) * 8;
    }

    TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")

    QMatrix* m = new QMatrix
    (
        device,
        height,
        width,
        groups,
        (uint32_t*) q_weight.data_ptr(),
        q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
        q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
        q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
        q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
        q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
        gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
        gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
        gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
        (half*) temp_dq.data_ptr()
    );

    return reinterpret_cast<uintptr_t> (m);
}

void gemm_half_q_half
(
    torch::Tensor a,
    uintptr_t b,
    torch::Tensor c,
    bool force_cuda
)
{
    QMatrix* qm = reinterpret_cast<QMatrix*> (b);

    TORCH_CHECK_DTYPE(a, kHalf);
    TORCH_CHECK_DTYPE(c, kHalf);
    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
    TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
    TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")

    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));

    gemm_half_q_half_cuda
    (
        at::cuda::getCurrentCUDABlasHandle(),
        (const half*) a.data_ptr(),
        qm,
        (half*) c.data_ptr(),
        c.size(0), // m
        c.size(1), // n
        a.size(1), // k
        true,
        NULL,
        force_cuda
    );
}

// Bindings

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
    m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
AutoAWQ_kernels/awq_ext/layernorm/layernorm.cu
ADDED
@@ -0,0 +1,113 @@
/*

Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu

*/

#include <torch/extension.h>
#include <cuda_fp16.h>
#include "reduction.cuh"
#include "layernorm.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>

static inline __device__ float to_float(half src)
{
    return __half2float(src);
}

static inline __device__ float to_float(float src)
{
    return src;
}

template<typename T>
__global__ void generalT5LayerNorm(
    const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
{
    // Layernorm module in the T5 style: no bias and no subtraction of the mean.
    const int tid = threadIdx.x;

    __shared__ float s_variance;
    float variance = 0.0f;

    float local_var_sum = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) {
        float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
        local_var_sum += diff * diff;
    }
    variance = blockReduceSum(local_var_sum);

    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / (float)n + layernorm_eps);
    }
    __syncthreads();

    for (int i = tid; i < n; i += blockDim.x) {
        output[blockIdx.x * n + i] =
            clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
    }
}


template<typename T>
void invokeGeneralT5LayerNorm(T* out,
                              const T* input,
                              const T* gamma,
                              // const T* beta,
                              const float layernorm_eps,
                              const int m,
                              const int n)
{
    dim3 grid(m);
    dim3 block(min(n, 1024));

    /* For general cases, n is equal to hidden_units, e.g., 512/1024.
       Since we have warp shuffle inside the code, block.x % 32 should be 0.
    */
    if (n % 32 != 0) {
        block.x = 1024;
    }

    block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x

    /* should pay attention to the rsqrt precision */
    generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3
}

template void invokeGeneralT5LayerNorm(half* out,
                                       const half* input,
                                       const half* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);

template void invokeGeneralT5LayerNorm(float* out,
                                       const float* input,
                                       const float* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);


// input b, n, c
void layernorm_forward_cuda(
    torch::Tensor _input,
    torch::Tensor _gamma,
    torch::Tensor _out,
    float eps)
{
    int m = _input.size(0) * _input.size(1);
    int n = _input.size(2);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));

    auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
    auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
    auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());

    invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
}
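For reference, generalT5LayerNorm above computes the T5-style (RMSNorm) variant of layer normalization: no mean is subtracted and no bias is added, so for a row x of length n the output is

    \[ \mathrm{out}_i \;=\; \gamma_i \cdot \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}} \]

which matches `s_variance = rsqrtf(variance / (float)n + layernorm_eps)` in the kernel, with the result clamped to the finite fp16 range by clamp_inf_for_half.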
AutoAWQ_kernels/awq_ext/layernorm/layernorm.h
ADDED
@@ -0,0 +1,3 @@
#include <torch/extension.h>

void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
AutoAWQ_kernels/awq_ext/layernorm/reduction.cuh
ADDED
@@ -0,0 +1,82 @@
/*

Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
*/

#pragma once
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <type_traits>

#define HALF_FLT_MAX 65504.F
#define FINAL_MASK 0xffffffff


template<typename T>
inline __device__ T add(T a, T b) {
    return a + b;
}

template<>
inline __device__ half2 add(half2 a, half2 b) {
    return __hadd2(a, b);
}

template<>
inline __device__ half add(half a, half b) {
    return __hadd(a, b);
}

template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
    return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
    static __shared__ T shared[32];
    int lane = threadIdx.x & 0x1f;
    int wid = threadIdx.x >> 5;

    val = warpReduceSum<T>(val);

    if (lane == 0)
        shared[wid] = val;

    __syncthreads();

    // Use blockDim.x / 32. instead of blockDim.x >> 5 so that the comparison
    // still behaves correctly when blockDim.x is not a multiple of 32.
    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
    val = warpReduceSum<T>(val);

    return val;
}


template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
    return input;
}

template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
    // clamp inf values to enable fp16 training
    return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
}