Eric Buehler commited on
Commit
05b1349
·
1 Parent(s): 6d6d594

Add metal kernels

Browse files
README.md CHANGED
@@ -1,3 +1,12 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - kernel
5
+ ---
6
+
7
+ ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/paged-attention)
8
+
9
+
10
+ ## attention
11
+
12
+ Paged attention kernels from [vLLM](https://github.com/vllm-project/).
build.toml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "paged_attention"
3
+
4
+ [torch]
5
+ src = [
6
+ "torch-ext/torch_binding.cpp",
7
+ "torch-ext/torch_binding.h"
8
+ ]
9
+
10
+ [kernel.activation_metal]
11
+ backend = "metal"
12
+ src = [
13
+ "paged-attention-metal/attention/paged_attention.metal",
14
+ "paged-attention-metal/cache/copy_blocks.metal",
15
+ "paged-attention-metal/cache/reshape_and_cache.metal",
16
+ "paged-attention-metal/utils.metal",
17
+ "paged-attention-metal/paged_attention.mm",
18
+ ]
19
+ depends = [ "torch" ]
flake.lock ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1733328505,
6
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rocm-nix": "rocm-nix"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1744976941,
45
+ "narHash": "sha256-+csrhVaT6Mj2j1FM7P2BDITvf1Xwj2AKdMm0IKZK340=",
46
+ "owner": "huggingface",
47
+ "repo": "kernel-builder",
48
+ "rev": "0a278c2e9aaf6003a4ec6fe35c7158624762de5a",
49
+ "type": "github"
50
+ },
51
+ "original": {
52
+ "owner": "huggingface",
53
+ "repo": "kernel-builder",
54
+ "type": "github"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1743559129,
60
+ "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
61
+ "owner": "nixos",
62
+ "repo": "nixpkgs",
63
+ "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "nixos",
68
+ "ref": "nixos-unstable-small",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "rocm-nix": {
74
+ "inputs": {
75
+ "nixpkgs": [
76
+ "kernel-builder",
77
+ "nixpkgs"
78
+ ]
79
+ },
80
+ "locked": {
81
+ "lastModified": 1743085847,
82
+ "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
83
+ "owner": "huggingface",
84
+ "repo": "rocm-nix",
85
+ "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
86
+ "type": "github"
87
+ },
88
+ "original": {
89
+ "owner": "huggingface",
90
+ "repo": "rocm-nix",
91
+ "type": "github"
92
+ }
93
+ },
94
+ "root": {
95
+ "inputs": {
96
+ "kernel-builder": "kernel-builder"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for attention kernels";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ };
17
+ }
paged-attention-metal/attention/pagedattention.metal ADDED
@@ -0,0 +1,1187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Updated from MLX commit has f70764a
2
+
3
+ #include "utils.metal"
4
+ #include <metal_simdgroup>
5
+ #include <metal_stdlib>
6
+
7
+ using namespace metal;
8
+
9
+ // ========================================== Generic vector types
10
+
11
+ // A vector type to store Q, K, V elements.
12
+ template <typename T, int VEC_SIZE> struct Vec {};
13
+
14
+ // A vector type to store FP32 accumulators.
15
+ template <typename T> struct FloatVec {};
16
+
17
+ // Template vector operations.
18
+ template <typename Acc, typename A, typename B> inline Acc mul(A a, B b);
19
+
20
+ template <typename T> inline float sum(T v);
21
+
22
+ template <typename T> inline float dot(T a, T b) {
23
+ return sum(mul<T, T, T>(a, b));
24
+ }
25
+
26
+ template <typename A, typename T> inline float dot(T a, T b) {
27
+ return sum(mul<A, T, T>(a, b));
28
+ }
29
+
30
+ // FP32 vector data types.
31
+ struct Float8_ {
32
+ float4 x;
33
+ float4 y;
34
+ };
35
+
36
+ template <> struct Vec<float, 1> {
37
+ using Type = float;
38
+ };
39
+ template <> struct Vec<float, 2> {
40
+ using Type = float2;
41
+ };
42
+ template <> struct Vec<float, 4> {
43
+ using Type = float4;
44
+ };
45
+ template <> struct Vec<float, 8> {
46
+ using Type = Float8_;
47
+ };
48
+
49
+ template <> struct FloatVec<float> {
50
+ using Type = float;
51
+ };
52
+ template <> struct FloatVec<float2> {
53
+ using Type = float2;
54
+ };
55
+ template <> struct FloatVec<float4> {
56
+ using Type = float4;
57
+ };
58
+ template <> struct FloatVec<Float8_> {
59
+ using Type = Float8_;
60
+ };
61
+
62
+ template <> inline float mul(float a, float b) { return a * b; }
63
+
64
+ template <> inline float2 mul(float2 a, float2 b) { return a * b; }
65
+
66
+ template <> inline float4 mul(float4 a, float4 b) { return a * b; }
67
+
68
+ template <> inline Float8_ mul(Float8_ a, Float8_ b) {
69
+ Float8_ c;
70
+ c.x = a.x * b.x;
71
+ c.y = a.y * b.y;
72
+ return c;
73
+ }
74
+
75
+ template <> inline float sum(float a) { return a; }
76
+
77
+ template <> inline float sum(float2 a) { return a.x + a.y; }
78
+
79
+ template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; }
80
+
81
+ template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); }
82
+
83
+ inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) {
84
+ Float8_ res;
85
+ res.x = fma(a.x, b.x, c.x);
86
+ res.y = fma(a.y, b.y, c.y);
87
+ return res;
88
+ }
89
+
90
+ inline void from_float(thread float &dst, float src) { dst = src; }
91
+ inline void from_float(thread float2 &dst, float2 src) { dst = src; }
92
+ inline void from_float(thread float4 &dst, float4 src) { dst = src; }
93
+ inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; }
94
+
95
+ // BF16 vector data types.
96
+ // #if defined(__HAVE_BFLOAT__)
97
+
98
+ // struct Bfloat8_ {
99
+ // bfloat4 x;
100
+ // bfloat4 y;
101
+ // };
102
+
103
+ // template<>
104
+ // struct Vec<bfloat, 1> {
105
+ // using Type = bfloat;
106
+ // };
107
+ // template<>
108
+ // struct Vec<bfloat, 2> {
109
+ // using Type = bfloat2;
110
+ // };
111
+ // template<>
112
+ // struct Vec<bfloat, 4> {
113
+ // using Type = bfloat4;
114
+ // };
115
+ // template<>
116
+ // struct Vec<bfloat, 8> {
117
+ // using Type = Bfloat8_;
118
+ // };
119
+
120
+ // template<>
121
+ // struct FloatVec<bfloat> {
122
+ // using Type = float;
123
+ // };
124
+ // template<>
125
+ // struct FloatVec<bfloat2> {
126
+ // using Type = float2;
127
+ // };
128
+ // template<>
129
+ // struct FloatVec<bfloat4> {
130
+ // using Type = float4;
131
+ // };
132
+ // template<>
133
+ // struct FloatVec<Bfloat8_> {
134
+ // using Type = Float8_;
135
+ // };
136
+
137
+ // template<>
138
+ // inline float mul(bfloat a, bfloat b) {
139
+ // return (float)a * (float)b;
140
+ // }
141
+ // template<>
142
+ // inline bfloat mul(bfloat a, bfloat b) {
143
+ // return a*b;
144
+ // }
145
+
146
+ // template<>
147
+ // inline float2 mul(bfloat2 a, bfloat2 b) {
148
+ // return (float2)a * (float2)b;
149
+ // }
150
+ // template<>
151
+ // inline bfloat2 mul(bfloat2 a, bfloat2 b) {
152
+ // return a * b;
153
+ // }
154
+
155
+ // template<>
156
+ // inline float4 mul(bfloat4 a, bfloat4 b) {
157
+ // return (float4)a * (float4)b;
158
+ // }
159
+ // template<>
160
+ // inline bfloat4 mul(bfloat4 a, bfloat4 b) {
161
+ // return a * b;
162
+ // }
163
+
164
+ // template<>
165
+ // inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
166
+ // Float8_ c;
167
+ // c.x = mul<float4, bfloat4, bfloat4>(a.x, b.x);
168
+ // c.y = mul<float4, bfloat4, bfloat4>(a.y, b.y);
169
+ // return c;
170
+ // }
171
+ // template<>
172
+ // inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
173
+ // Bfloat8_ c;
174
+ // c.x = mul<bfloat4, bfloat4, bfloat4>(a.x, b.x);
175
+ // c.y = mul<bfloat4, bfloat4, bfloat4>(a.y, b.y);
176
+ // return c;
177
+ // }
178
+
179
+ // template<>
180
+ // inline float sum(bfloat a) {
181
+ // return (float)a;
182
+ // }
183
+
184
+ // template<>
185
+ // inline float sum(bfloat2 a) {
186
+ // return (float)a.x + (float)a.y;
187
+ // }
188
+
189
+ // template<>
190
+ // inline float sum(bfloat4 a) {
191
+ // return sum(a.x) + sum(a.y);
192
+ // }
193
+
194
+ // template<>
195
+ // inline float sum(Bfloat8_ a) {
196
+ // return sum(a.x) + sum(a.y);
197
+ // }
198
+
199
+ // inline float fma(bfloat a, bfloat b, float c) {
200
+ // return (float)a * (float)b + c;
201
+ // }
202
+
203
+ // inline float2 fma(bfloat2 a, bfloat2 b, float2 c) {
204
+ // return (float2)a * (float2)b + c;
205
+ // }
206
+
207
+ // inline float4 fma(bfloat4 a, bfloat4 b, float4 c) {
208
+ // return (float4)a * (float4)b + c;
209
+ // }
210
+
211
+ // inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
212
+ // Float8_ res;
213
+ // res.x = fma((float4)a.x, (float4)b.x, (float4)c.x);
214
+ // res.y = fma((float4)a.y, (float4)b.y, (float4)c.y);
215
+ // return res;
216
+ // }
217
+ // inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
218
+ // Bfloat8_ res;
219
+ // res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x);
220
+ // res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y);
221
+ // return c;
222
+ // }
223
+
224
+ // inline void from_float(thread bfloat& dst, float src) {
225
+ // dst = static_cast<bfloat>(src);
226
+ // }
227
+ // inline void from_float(thread bfloat2& dst, float2 src) {
228
+ // dst.x = static_cast<bfloat>(src.x);
229
+ // dst.y = static_cast<bfloat>(src.y);
230
+ // }
231
+ // inline void from_float(thread bfloat4& dst, float4 src) {
232
+ // dst.x = static_cast<bfloat>(src.x);
233
+ // dst.y = static_cast<bfloat>(src.y);
234
+ // dst.z = static_cast<bfloat>(src.z);
235
+ // dst.w = static_cast<bfloat>(src.w);
236
+ // }
237
+ // inline void from_float(thread Bfloat8_& dst, Float8_ src) {
238
+ // bfloat4 x;
239
+ // bfloat4 y;
240
+ // from_float(x, src.x);
241
+ // from_float(y, src.y);
242
+ // dst.x = x;
243
+ // dst.y = y;
244
+ // }
245
+
246
+ // #else
247
+
248
+ struct Bfloat2_ {
249
+ bfloat16_t x;
250
+ bfloat16_t y;
251
+ };
252
+
253
+ struct Bfloat4_ {
254
+ Bfloat2_ x;
255
+ Bfloat2_ y;
256
+ };
257
+
258
+ struct Bfloat8_ {
259
+ Bfloat4_ x;
260
+ Bfloat4_ y;
261
+ };
262
+
263
+ template <> struct Vec<bfloat16_t, 1> {
264
+ using Type = bfloat16_t;
265
+ };
266
+ template <> struct Vec<bfloat16_t, 2> {
267
+ using Type = Bfloat2_;
268
+ };
269
+ template <> struct Vec<bfloat16_t, 4> {
270
+ using Type = Bfloat4_;
271
+ };
272
+ template <> struct Vec<bfloat16_t, 8> {
273
+ using Type = Bfloat8_;
274
+ };
275
+
276
+ template <> struct FloatVec<bfloat16_t> {
277
+ using Type = float;
278
+ };
279
+ template <> struct FloatVec<Bfloat2_> {
280
+ using Type = float2;
281
+ };
282
+ template <> struct FloatVec<Bfloat4_> {
283
+ using Type = float4;
284
+ };
285
+ template <> struct FloatVec<Bfloat8_> {
286
+ using Type = Float8_;
287
+ };
288
+
289
+ template <> inline float mul(bfloat16_t a, bfloat16_t b) {
290
+ return (float)a * (float)b;
291
+ }
292
+ template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; }
293
+
294
+ template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) {
295
+ float2 a_f((float)a.x, (float)a.y);
296
+ float2 b_f((float)b.x, (float)b.y);
297
+ return a_f * b_f;
298
+ }
299
+ template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) {
300
+ Bfloat2_ c;
301
+ c.x = a.x * b.x;
302
+ c.y = a.y * b.y;
303
+ return c;
304
+ }
305
+
306
+ template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) {
307
+ float2 x = mul<float2, Bfloat2_, Bfloat2_>(a.x, b.x);
308
+ float2 y = mul<float2, Bfloat2_, Bfloat2_>(a.y, b.y);
309
+ float4 c;
310
+ c.x = x.x;
311
+ c.y = x.y;
312
+ c.z = y.x;
313
+ c.w = y.y;
314
+ return c;
315
+ }
316
+ template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) {
317
+ Bfloat4_ c;
318
+ c.x = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.x, b.x);
319
+ c.y = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.y, b.y);
320
+ return c;
321
+ }
322
+
323
+ template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
324
+ Float8_ c;
325
+ c.x = mul<float4, Bfloat4_, Bfloat4_>(a.x, b.x);
326
+ c.y = mul<float4, Bfloat4_, Bfloat4_>(a.y, b.y);
327
+ return c;
328
+ }
329
+ template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
330
+ Bfloat8_ c;
331
+ c.x = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.x, b.x);
332
+ c.y = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.y, b.y);
333
+ return c;
334
+ }
335
+
336
+ template <> inline float sum(bfloat16_t a) { return (float)a; }
337
+
338
+ template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; }
339
+
340
+ template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); }
341
+
342
+ template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); }
343
+
344
+ inline float fma(bfloat16_t a, bfloat16_t b, float c) {
345
+ return (float)a * (float)b + c;
346
+ }
347
+ inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) {
348
+ return a * b + c;
349
+ }
350
+
351
+ inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) {
352
+ float2 a_f((float)a.x, (float)a.y);
353
+ float2 b_f((float)b.x, (float)b.y);
354
+ return a_f * b_f + c;
355
+ }
356
+ inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) {
357
+ Bfloat2_ res;
358
+ res.x = a.x * b.x + c.x;
359
+ res.y = a.y * b.y + c.y;
360
+ return res;
361
+ }
362
+
363
+ inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) {
364
+ float4 res;
365
+ res.x = fma(a.x.x, b.x.x, c.x);
366
+ res.y = fma(a.x.y, b.x.y, c.y);
367
+ res.z = fma(a.y.x, b.y.x, c.z);
368
+ res.w = fma(a.y.y, b.y.y, c.w);
369
+ return res;
370
+ }
371
+ inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) {
372
+ Bfloat4_ res;
373
+ res.x = fma(a.x, b.x, c.x);
374
+ res.y = fma(a.y, b.y, c.y);
375
+ return res;
376
+ }
377
+
378
+ inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
379
+ float4 x = fma(a.x, b.x, c.x);
380
+ float4 y = fma(a.y, b.y, c.y);
381
+ Float8_ res;
382
+ res.x = x;
383
+ res.y = y;
384
+ return res;
385
+ }
386
+ inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
387
+ Bfloat8_ res;
388
+ res.x = fma(a.x, b.x, c.x);
389
+ res.y = fma(a.y, b.y, c.y);
390
+ return res;
391
+ }
392
+
393
+ inline void from_float(thread bfloat16_t &dst, float src) {
394
+ dst = static_cast<bfloat16_t>(src);
395
+ }
396
+ inline void from_float(thread Bfloat2_ &dst, float2 src) {
397
+ dst.x = static_cast<bfloat16_t>(src.x);
398
+ dst.y = static_cast<bfloat16_t>(src.y);
399
+ }
400
+ inline void from_float(thread Bfloat4_ &dst, float4 src) {
401
+ dst.x.x = static_cast<bfloat16_t>(src.x);
402
+ dst.x.y = static_cast<bfloat16_t>(src.y);
403
+ dst.y.x = static_cast<bfloat16_t>(src.z);
404
+ dst.y.y = static_cast<bfloat16_t>(src.w);
405
+ }
406
+ inline void from_float(thread Bfloat8_ &dst, Float8_ src) {
407
+ Bfloat4_ x;
408
+ Bfloat4_ y;
409
+ from_float(x, src.x);
410
+ from_float(y, src.y);
411
+ dst.x = x;
412
+ dst.y = y;
413
+ }
414
+
415
+ // #endif
416
+
417
+ // FP16 vector data types.
418
+ struct Half8_ {
419
+ half4 x;
420
+ half4 y;
421
+ };
422
+
423
+ template <> struct Vec<half, 1> {
424
+ using Type = half;
425
+ };
426
+ template <> struct Vec<half, 2> {
427
+ using Type = half2;
428
+ };
429
+ template <> struct Vec<half, 4> {
430
+ using Type = half4;
431
+ };
432
+ template <> struct Vec<half, 8> {
433
+ using Type = Half8_;
434
+ };
435
+
436
+ template <> struct FloatVec<half> {
437
+ using Type = float;
438
+ };
439
+ template <> struct FloatVec<half2> {
440
+ using Type = float2;
441
+ };
442
+ template <> struct FloatVec<half4> {
443
+ using Type = float4;
444
+ };
445
+ template <> struct FloatVec<Half8_> {
446
+ using Type = Float8_;
447
+ };
448
+
449
+ template <> inline float mul(half a, half b) { return (float)a * (float)b; }
450
+ template <> inline half mul(half a, half b) { return a * b; }
451
+
452
+ template <> inline float2 mul(half2 a, half2 b) {
453
+ return (float2)a * (float2)b;
454
+ }
455
+ template <> inline half2 mul(half2 a, half2 b) { return a * b; }
456
+
457
+ template <> inline float4 mul(half4 a, half4 b) {
458
+ return (float4)a * (float4)b;
459
+ }
460
+ template <> inline half4 mul(half4 a, half4 b) { return a * b; }
461
+
462
+ template <> inline Float8_ mul(Half8_ a, Half8_ b) {
463
+ float4 x = mul<float4, half4, half4>(a.x, b.x);
464
+ float4 y = mul<float4, half4, half4>(a.y, b.y);
465
+ Float8_ c;
466
+ c.x = x;
467
+ c.y = y;
468
+ return c;
469
+ }
470
+ template <> inline Half8_ mul(Half8_ a, Half8_ b) {
471
+ Half8_ c;
472
+ c.x = mul<half4, half4, half4>(a.x, b.x);
473
+ c.y = mul<half4, half4, half4>(a.y, b.y);
474
+ return c;
475
+ }
476
+
477
+ template <> inline float sum(half a) { return (float)a; }
478
+
479
+ template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; }
480
+
481
+ template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; }
482
+
483
+ template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); }
484
+
485
+ inline float fma(half a, half b, float c) { return (float)a * (float)b + c; }
486
+
487
+ inline float2 fma(half2 a, half2 b, float2 c) {
488
+ return (float2)a * (float2)b + c;
489
+ }
490
+
491
+ inline float4 fma(half4 a, half4 b, float4 c) {
492
+ return (float4)a * (float4)b + c;
493
+ }
494
+
495
+ inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) {
496
+ float4 x = fma(a.x, b.x, c.x);
497
+ float4 y = fma(a.y, b.y, c.y);
498
+ Float8_ res;
499
+ res.x = x;
500
+ res.y = y;
501
+ return res;
502
+ }
503
+ inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) {
504
+ Half8_ res;
505
+ res.x = fma(a.x, b.x, c.x);
506
+ res.y = fma(a.y, b.y, c.y);
507
+ return res;
508
+ }
509
+
510
+ inline void from_float(thread half &dst, float src) {
511
+ dst = static_cast<half>(src);
512
+ }
513
+ inline void from_float(thread half2 &dst, float2 src) {
514
+ dst.x = static_cast<half>(src.x);
515
+ dst.y = static_cast<half>(src.y);
516
+ }
517
+ inline void from_float(thread half4 &dst, float4 src) {
518
+ dst.x = static_cast<half>(src.x);
519
+ dst.y = static_cast<half>(src.y);
520
+ dst.z = static_cast<half>(src.z);
521
+ dst.w = static_cast<half>(src.w);
522
+ }
523
+ inline void from_float(thread Half8_ &dst, Float8_ src) {
524
+ half4 x;
525
+ half4 y;
526
+ from_float(x, src.x);
527
+ from_float(y, src.y);
528
+ dst.x = x;
529
+ dst.y = y;
530
+ }
531
+
532
+ // ========================================== Dot product utilities
533
+
534
+ // TODO(EricLBuehler): optimize with vectorization
535
+ template <int THREAD_GROUP_SIZE, typename Vec, int N>
536
+ inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
537
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
538
+ using A_vec = typename FloatVec<Vec>::Type;
539
+ A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
540
+ #pragma unroll
541
+ for (int ii = 1; ii < N; ++ii) {
542
+ qk_vec = fma(q[ii], k[ii], qk_vec);
543
+ }
544
+
545
+ // Finalize the reduction across lanes.
546
+ float qk = sum(qk_vec);
547
+ #pragma unroll
548
+ for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
549
+ qk += simd_shuffle_xor(qk, mask);
550
+ }
551
+ return qk;
552
+ }
553
+
554
+ template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
555
+ template <typename Vec, int N>
556
+ static inline float dot(const threadgroup Vec (&q)[N],
557
+ const thread Vec (&k)[N]) {
558
+ return qk_dot_<THREAD_GROUP_SIZE>(q, k);
559
+ }
560
+ };
561
+
562
+ // ========================================== Block sum utility
563
+
564
+ // Utility function for attention softmax.
565
+ template <int NUM_WARPS, int NUM_SIMD_LANES>
566
+ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
567
+ uint simd_lid) {
568
+ // Compute the sum per simdgroup.
569
+ #pragma unroll
570
+ for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
571
+ sum += simd_shuffle_xor(sum, mask);
572
+ }
573
+
574
+ // Simd leaders store the data to shared memory.
575
+ if (simd_lid == 0) {
576
+ red_smem[simd_tid] = sum;
577
+ }
578
+
579
+ // Make sure the data is in shared memory.
580
+ threadgroup_barrier(mem_flags::mem_threadgroup);
581
+
582
+ // The warps compute the final sums.
583
+ if (simd_lid < NUM_WARPS) {
584
+ sum = red_smem[simd_lid];
585
+ }
586
+
587
+ // Parallel reduction inside the simd group.
588
+ #pragma unroll
589
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
590
+ sum += simd_shuffle_xor(sum, mask);
591
+ }
592
+
593
+ // Broadcast to other threads.
594
+ return simd_shuffle(sum, 0);
595
+ }
596
+
597
+ // ========================================== Paged Attention kernel
598
+
599
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
600
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
601
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
602
+
603
+ constant bool use_partitioning [[function_constant(10)]];
604
+ constant bool use_alibi [[function_constant(20)]];
605
+
606
+ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
607
+ int NUM_SIMD_LANES, int PARTITION_SIZE = 0>
608
+ [[kernel]] void paged_attention(
609
+ device float *exp_sums
610
+ [[buffer(0), function_constant(use_partitioning)]], // [num_seqs, num_heads,
611
+ // max_num_partitions]
612
+ device float *max_logits
613
+ [[buffer(1), function_constant(use_partitioning)]], // [num_seqs, num_heads,
614
+ // max_num_partitions]
615
+ device T *out
616
+ [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size]
617
+ device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size]
618
+ device const T *k_cache
619
+ [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x]
620
+ device const T *v_cache
621
+ [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size]
622
+ const constant int &num_kv_heads [[buffer(6)]], // [num_heads]
623
+ const constant float &scale [[buffer(7)]],
624
+ const constant float &softcapping [[buffer(8)]],
625
+ device const uint32_t *block_tables
626
+ [[buffer(9)]], // [num_seqs, max_num_blocks_per_seq]
627
+ device const uint32_t *context_lens [[buffer(10)]], // [num_seqs]
628
+ const constant int &max_num_blocks_per_seq [[buffer(11)]],
629
+ device const float *alibi_slopes
630
+ [[buffer(12), function_constant(use_alibi)]], // [num_heads]
631
+ const constant int &q_stride [[buffer(13)]],
632
+ const constant int &kv_block_stride [[buffer(14)]],
633
+ const constant int &kv_head_stride [[buffer(15)]],
634
+ threadgroup char *shared_mem [[threadgroup(0)]],
635
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
636
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]],
637
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
638
+ uint simd_tid [[simdgroup_index_in_threadgroup]],
639
+ uint simd_lid [[thread_index_in_simdgroup]]) {
640
+ const int seq_idx = threadgroup_position_in_grid.y;
641
+ const int partition_idx = threadgroup_position_in_grid.z;
642
+ const int max_num_partitions = threadgroups_per_grid.z;
643
+ const int thread_idx = thread_position_in_threadgroup.x;
644
+ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
645
+ const uint32_t context_len = context_lens[seq_idx];
646
+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
647
+ // No work to do. Terminate the thread block.
648
+ return;
649
+ }
650
+
651
+ const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
652
+ const int num_blocks_per_partition =
653
+ USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
654
+
655
+ // [start_block_idx, end_block_idx) is the range of blocks to process.
656
+ const int start_block_idx =
657
+ USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
658
+ const int end_block_idx =
659
+ MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
660
+ const int num_blocks = end_block_idx - start_block_idx;
661
+
662
+ // [start_token_idx, end_token_idx) is the range of tokens to process.
663
+ const int start_token_idx = start_block_idx * BLOCK_SIZE;
664
+ const int end_token_idx =
665
+ MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
666
+ const int num_tokens = end_token_idx - start_token_idx;
667
+
668
+ constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1);
669
+ constexpr int NUM_THREAD_GROUPS =
670
+ NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
671
+ // divides NUM_THREADS
672
+ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
673
+ constexpr int NUM_TOKENS_PER_THREAD_GROUP =
674
+ DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES);
675
+ constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
676
+ const int warp_idx = simd_tid;
677
+ const int lane = simd_lid;
678
+
679
+ const int head_idx = threadgroup_position_in_grid.x;
680
+ const int num_heads = threadgroups_per_grid.x;
681
+ const int num_queries_per_kv = num_heads / num_kv_heads;
682
+ const int kv_head_idx = head_idx / num_queries_per_kv;
683
+ const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx];
684
+
685
+ // A vector type to store a part of a key or a query.
686
+ // The vector size is configured in such a way that the threads in a thread
687
+ // group fetch or compute 16 bytes at a time. For example, if the size of a
688
+ // thread group is 4 and the data type is half, then the vector size is 16 /
689
+ // (4 * sizeof(half)) == 2.
690
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1);
691
+ using K_vec = typename Vec<T, VEC_SIZE>::Type;
692
+ using Q_vec = typename Vec<T, VEC_SIZE>::Type;
693
+
694
+ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
695
+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
696
+
697
+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
698
+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
699
+
700
+ // Load the query to registers.
701
+ // Each thread in a thread group has a different part of the query.
702
+ // For example, if the thread group size is 4, then the first thread in the
703
+ // group has 0, 4, 8, ... th vectors of the query, and the second thread has
704
+ // 1, 5, 9, ... th vectors of the query, and so on.
705
+ const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
706
+ threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
707
+ #pragma unroll
708
+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
709
+ i += NUM_THREAD_GROUPS) {
710
+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
711
+ q_vecs[thread_group_offset][i] =
712
+ *reinterpret_cast<const device Q_vec *>(q_ptr + vec_idx * VEC_SIZE);
713
+ }
714
+ threadgroup_barrier(mem_flags::mem_threadgroup);
715
+
716
+ // Use fp32 on softmax logits for better accuracy
717
+ threadgroup float *logits = reinterpret_cast<threadgroup float *>(shared_mem);
718
+ // Workspace for reduction
719
+ threadgroup float red_smem[2 * NUM_WARPS];
720
+
721
+ // x == THREAD_GROUP_SIZE * VEC_SIZE
722
+ // Each thread group fetches x elements from the key at a time.
723
+ constexpr int x = 16 / sizeof(T);
724
+ float qk_max = -FLT_MAX;
725
+
726
+ // Iterate over the key blocks.
727
+ // Each warp fetches a block of keys for each iteration.
728
+ // Each thread group in a warp fetches a key from the block, and computes
729
+ // dot product with the query.
730
+ const device uint32_t *block_table =
731
+ block_tables + seq_idx * max_num_blocks_per_seq;
732
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
733
+ block_idx += NUM_WARPS) {
734
+ // NOTE: The block number is stored in int32. However, we cast it to int64
735
+ // because int32 can lead to overflow when this variable is multiplied by
736
+ // large numbers (e.g., kv_block_stride).
737
+ const int64_t physical_block_number =
738
+ static_cast<int64_t>(block_table[block_idx]);
739
+
740
+ // Load a key to registers.
741
+ // Each thread in a thread group has a different part of the key.
742
+ // For example, if the thread group size is 4, then the first thread in the
743
+ // group has 0, 4, 8, ... th vectors of the key, and the second thread has
744
+ // 1, 5, 9, ... th vectors of the key, and so on.
745
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
746
+ const int physical_block_offset =
747
+ (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE;
748
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
749
+ K_vec k_vecs[NUM_VECS_PER_THREAD];
750
+
751
+ #pragma unroll
752
+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
753
+ const device T *k_ptr =
754
+ k_cache + physical_block_number * kv_block_stride +
755
+ kv_head_idx * kv_head_stride + physical_block_offset * x;
756
+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
757
+ const int offset1 = (vec_idx * VEC_SIZE) / x;
758
+ const int offset2 = (vec_idx * VEC_SIZE) % x;
759
+ k_vecs[j] = *reinterpret_cast<const device K_vec *>(
760
+ k_ptr + offset1 * BLOCK_SIZE * x + offset2);
761
+ }
762
+
763
+ // Compute dot product.
764
+ // This includes a reduction across the threads in the same thread group.
765
+ float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
766
+ q_vecs[thread_group_offset], k_vecs);
767
+
768
+ // Apply softcapping
769
+ if (softcapping != 1.0) {
770
+ qk = precise::tanh(qk / softcapping) * softcapping;
771
+ }
772
+
773
+ // Add the ALiBi bias if slopes are given.
774
+ qk +=
775
+ (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
776
+
777
+ if (thread_group_offset == 0) {
778
+ // Store the partial reductions to shared memory.
779
+ // NOTE: It is required to zero out the masked logits.
780
+ const bool mask = token_idx >= context_len;
781
+ logits[token_idx - start_token_idx] = mask ? 0.f : qk;
782
+ // Update the max value.
783
+ qk_max = mask ? qk_max : max(qk_max, qk);
784
+ }
785
+ }
786
+ }
787
+
788
+ // Perform reduction across the threads in the same warp to get the
789
+ // max qk value for each "warp" (not across the thread block yet).
790
+ // The 0-th thread of each thread group already has its max qk value.
791
+ #pragma unroll
792
+ for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
793
+ qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
794
+ }
795
+ if (lane == 0) {
796
+ red_smem[warp_idx] = qk_max;
797
+ }
798
+ threadgroup_barrier(mem_flags::mem_threadgroup);
799
+
800
+ // Get the max qk value for the sequence.
801
+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
802
+ #pragma unroll
803
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
804
+ qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
805
+ }
806
+ // Broadcast the max qk value to all threads.
807
+ qk_max = simd_shuffle(qk_max, 0);
808
+
809
+ // Get the sum of the exp values.
810
+ float exp_sum = 0.f;
811
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
812
+ float val = exp(logits[i] - qk_max);
813
+ logits[i] = val;
814
+ exp_sum += val;
815
+ }
816
+ exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(&red_smem[NUM_WARPS], exp_sum,
817
+ simd_tid, simd_lid);
818
+
819
+ // Compute softmax.
820
+ const float inv_sum = divide(1.f, exp_sum + 1e-6f);
821
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
822
+ logits[i] *= inv_sum;
823
+ }
824
+ threadgroup_barrier(mem_flags::mem_threadgroup);
825
+
826
+ // If partitioning is enabled, store the max logit and exp_sum.
827
+ if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
828
+ device float *max_logits_ptr =
829
+ max_logits + seq_idx * num_heads * max_num_partitions +
830
+ head_idx * max_num_partitions + partition_idx;
831
+ *max_logits_ptr = qk_max;
832
+ device float *exp_sums_ptr = exp_sums +
833
+ seq_idx * num_heads * max_num_partitions +
834
+ head_idx * max_num_partitions + partition_idx;
835
+ *exp_sums_ptr = exp_sum;
836
+ }
837
+
838
+ // Each thread will fetch 16 bytes from the value cache at a time.
839
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE);
840
+ using V_vec = typename Vec<T, V_VEC_SIZE>::Type;
841
+ using L_vec = typename Vec<T, V_VEC_SIZE>::Type;
842
+ using Float_L_vec = typename FloatVec<L_vec>::Type;
843
+
844
+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
845
+ constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW;
846
+ constexpr int NUM_ROWS_PER_THREAD =
847
+ DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
848
+
849
+ // NOTE: We use FP32 for the accumulator for better accuracy.
850
+ float accs[NUM_ROWS_PER_THREAD];
851
+ #pragma unroll
852
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
853
+ accs[i] = 0.f;
854
+ }
855
+
856
+ T zero_value = 0;
857
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
858
+ block_idx += NUM_WARPS) {
859
+ // NOTE: The block number is stored in int32. However, we cast it to int64
860
+ // because int32 can lead to overflow when this variable is multiplied by
861
+ // large numbers (e.g., kv_block_stride).
862
+ const int64_t physical_block_number =
863
+ static_cast<int64_t>(block_table[block_idx]);
864
+ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
865
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
866
+ L_vec logits_vec;
867
+ Float_L_vec logits_float_vec = *reinterpret_cast<threadgroup Float_L_vec *>(
868
+ logits + token_idx - start_token_idx);
869
+ from_float(logits_vec, logits_float_vec);
870
+
871
+ const device T *v_ptr = v_cache + physical_block_number * kv_block_stride +
872
+ kv_head_idx * kv_head_stride;
873
+ #pragma unroll
874
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
875
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
876
+ if (row_idx < HEAD_SIZE) {
877
+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
878
+ // NOTE: When v_vec contains the tokens that are out of the context,
879
+ // we should explicitly zero out the values since they may contain NaNs.
880
+ // See
881
+ // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
882
+ V_vec v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
883
+ if (block_idx == num_context_blocks - 1) {
884
+ thread T *v_vec_ptr = reinterpret_cast<thread T *>(&v_vec);
885
+ #pragma unroll
886
+ for (int j = 0; j < V_VEC_SIZE; j++) {
887
+ v_vec_ptr[j] =
888
+ token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
889
+ }
890
+ }
891
+ accs[i] += dot(logits_vec, v_vec);
892
+ }
893
+ }
894
+ }
895
+
896
+ // Perform reduction within each warp.
897
+ #pragma unroll
898
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
899
+ float acc = accs[i];
900
+ #pragma unroll
901
+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
902
+ acc += simd_shuffle_xor(acc, mask);
903
+ }
904
+ accs[i] = acc;
905
+ }
906
+
907
+ // NOTE: A barrier is required because the shared memory space for logits
908
+ // is reused for the output.
909
+ threadgroup_barrier(mem_flags::mem_threadgroup);
910
+
911
+ // Perform reduction across warps.
912
+ threadgroup float *out_smem =
913
+ reinterpret_cast<threadgroup float *>(shared_mem);
914
+ #pragma unroll
915
+ for (int i = NUM_WARPS; i > 1; i /= 2) {
916
+ int mid = i / 2;
917
+ // Upper warps write to shared memory.
918
+ if (warp_idx >= mid && warp_idx < i) {
919
+ threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
920
+ #pragma unroll
921
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
922
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
923
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
924
+ dst[row_idx] = accs[i];
925
+ }
926
+ }
927
+ }
928
+ threadgroup_barrier(mem_flags::mem_threadgroup);
929
+
930
+ // Lower warps update the output.
931
+ if (warp_idx < mid) {
932
+ const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE];
933
+ #pragma unroll
934
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
935
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
936
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
937
+ accs[i] += src[row_idx];
938
+ }
939
+ }
940
+ }
941
+ threadgroup_barrier(mem_flags::mem_threadgroup);
942
+ }
943
+
944
+ // Write the final output.
945
+ if (warp_idx == 0) {
946
+ device T *out_ptr =
947
+ out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
948
+ head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
949
+ #pragma unroll
950
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
951
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
952
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
953
+ *(out_ptr + row_idx) = T(accs[i]);
954
+ }
955
+ }
956
+ }
957
+ }
958
+
959
+ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
960
+ int PARTITION_SIZE = 0>
961
+ [[kernel]] void paged_attention_v2_reduce(
962
+ device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]],
963
+ const device float *max_logits [[buffer(2)]],
964
+ const device T *tmp_out [[buffer(3)]],
965
+ device uint32_t *context_lens [[buffer(4)]],
966
+ const constant int &max_num_partitions [[buffer(5)]],
967
+ threadgroup char *shared_mem [[threadgroup(0)]],
968
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
969
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]],
970
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
971
+ uint3 threads_per_threadgroup [[threads_per_threadgroup]],
972
+ uint simd_tid [[simdgroup_index_in_threadgroup]],
973
+ uint simd_lid [[thread_index_in_simdgroup]]) {
974
+ const int num_heads = threadgroups_per_grid.x;
975
+ const int head_idx = threadgroup_position_in_grid.x;
976
+ const int seq_idx = threadgroup_position_in_grid.y;
977
+ const uint32_t context_len = context_lens[seq_idx];
978
+ const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
979
+ if (num_partitions == 1) {
980
+ // No need to reduce. Only copy tmp_out to out.
981
+ device T *out_ptr =
982
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
983
+ const device T *tmp_out_ptr =
984
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
985
+ head_idx * max_num_partitions * HEAD_SIZE;
986
+ for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
987
+ i += threads_per_threadgroup.x) {
988
+ out_ptr[i] = tmp_out_ptr[i];
989
+ }
990
+ // Terminate the thread block.
991
+ return;
992
+ }
993
+
994
+ constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
995
+ const int warp_idx = simd_tid;
996
+ const int lane = simd_lid;
997
+
998
+ // Workspace for reduction.
999
+ threadgroup float red_smem[2 * NUM_WARPS];
1000
+
1001
+ // Load max logits to shared memory.
1002
+ threadgroup float *shared_max_logits =
1003
+ reinterpret_cast<threadgroup float *>(shared_mem);
1004
+ const device float *max_logits_ptr =
1005
+ max_logits + seq_idx * num_heads * max_num_partitions +
1006
+ head_idx * max_num_partitions;
1007
+ float max_logit = -FLT_MAX;
1008
+ for (int i = thread_position_in_threadgroup.x; i < num_partitions;
1009
+ i += threads_per_threadgroup.x) {
1010
+ const float l = max_logits_ptr[i];
1011
+ shared_max_logits[i] = l;
1012
+ max_logit = max(max_logit, l);
1013
+ }
1014
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1015
+
1016
+ // Get the global max logit.
1017
+ // Reduce within the warp.
1018
+ #pragma unroll
1019
+ for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
1020
+ max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
1021
+ }
1022
+ if (lane == 0) {
1023
+ red_smem[warp_idx] = max_logit;
1024
+ }
1025
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1026
+ // Reduce across warps.
1027
+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
1028
+ #pragma unroll
1029
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
1030
+ max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
1031
+ }
1032
+ // Broadcast the max value to all threads.
1033
+ max_logit = simd_shuffle(max_logit, 0);
1034
+
1035
+ // Load rescaled exp sums to shared memory.
1036
+ threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
1037
+ shared_mem + sizeof(float) * num_partitions);
1038
+ const device float *exp_sums_ptr = exp_sums +
1039
+ seq_idx * num_heads * max_num_partitions +
1040
+ head_idx * max_num_partitions;
1041
+ float global_exp_sum = 0.0f;
1042
+ for (int i = thread_position_in_threadgroup.x; i < num_partitions;
1043
+ i += threads_per_threadgroup.x) {
1044
+ float l = shared_max_logits[i];
1045
+ float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit);
1046
+ global_exp_sum += rescaled_exp_sum;
1047
+ shared_exp_sums[i] = rescaled_exp_sum;
1048
+ }
1049
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1050
+ global_exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(
1051
+ &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid);
1052
+ const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f);
1053
+
1054
+ // Aggregate tmp_out to out.
1055
+ const device T *tmp_out_ptr =
1056
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
1057
+ head_idx * max_num_partitions * HEAD_SIZE;
1058
+ device T *out_ptr =
1059
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
1060
+ #pragma unroll
1061
+ for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
1062
+ i += NUM_THREADS) {
1063
+ float acc = 0.0f;
1064
+ for (int j = 0; j < num_partitions; ++j) {
1065
+ acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
1066
+ inv_global_exp_sum;
1067
+ }
1068
+ out_ptr[i] = T(acc);
1069
+ }
1070
+ }
1071
+
1072
+ #define instantiate_paged_attention_inner( \
1073
+ type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
1074
+ template \
1075
+ [[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
1076
+ "_nt" #num_threads "_nsl" #num_simd_lanes \
1077
+ "_ps" #partition_size)]] [[kernel]] void \
1078
+ paged_attention<type, head_size, block_size, num_threads, \
1079
+ num_simd_lanes, partition_size>( \
1080
+ device float *exp_sums \
1081
+ [[buffer(0), function_constant(use_partitioning)]], \
1082
+ device float *max_logits \
1083
+ [[buffer(1), function_constant(use_partitioning)]], \
1084
+ device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \
1085
+ device const type *k_cache [[buffer(4)]], \
1086
+ device const type *v_cache [[buffer(5)]], \
1087
+ const constant int &num_kv_heads [[buffer(6)]], \
1088
+ const constant float &scale [[buffer(7)]], \
1089
+ const constant float &softcapping [[buffer(8)]], \
1090
+ device const uint32_t *block_tables [[buffer(9)]], \
1091
+ device const uint32_t *context_lens [[buffer(10)]], \
1092
+ const constant int &max_num_blocks_per_seq [[buffer(11)]], \
1093
+ device const float *alibi_slopes \
1094
+ [[buffer(12), function_constant(use_alibi)]], \
1095
+ const constant int &q_stride [[buffer(13)]], \
1096
+ const constant int &kv_block_stride [[buffer(14)]], \
1097
+ const constant int &kv_head_stride [[buffer(15)]], \
1098
+ threadgroup char *shared_mem [[threadgroup(0)]], \
1099
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
1100
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
1101
+ uint3 thread_position_in_threadgroup \
1102
+ [[thread_position_in_threadgroup]], \
1103
+ uint simd_tid [[simdgroup_index_in_threadgroup]], \
1104
+ uint simd_lid [[thread_index_in_simdgroup]]);
1105
+
1106
+ #define instantiate_paged_attention_v2_reduce_inner( \
1107
+ type, head_size, num_threads, num_simd_lanes, partition_size) \
1108
+ template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
1109
+ "_nt" #num_threads "_nsl" #num_simd_lanes \
1110
+ "_ps" #partition_size)]] [[kernel]] void \
1111
+ paged_attention_v2_reduce<type, head_size, num_threads, num_simd_lanes, \
1112
+ partition_size>( \
1113
+ device type * out [[buffer(0)]], \
1114
+ const device float *exp_sums [[buffer(1)]], \
1115
+ const device float *max_logits [[buffer(2)]], \
1116
+ const device type *tmp_out [[buffer(3)]], \
1117
+ device uint32_t *context_lens [[buffer(4)]], \
1118
+ const constant int &max_num_partitions [[buffer(5)]], \
1119
+ threadgroup char *shared_mem [[threadgroup(0)]], \
1120
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
1121
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
1122
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
1123
+ uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
1124
+ uint simd_tid [[simdgroup_index_in_threadgroup]], \
1125
+ uint simd_lid [[thread_index_in_simdgroup]]);
1126
+
1127
+ #define instantiate_paged_attention_heads(type, block_size, num_threads, \
1128
+ num_simd_lanes, partition_size) \
1129
+ instantiate_paged_attention_inner(type, 64, block_size, num_threads, \
1130
+ num_simd_lanes, partition_size); \
1131
+ instantiate_paged_attention_inner(type, 80, block_size, num_threads, \
1132
+ num_simd_lanes, partition_size); \
1133
+ instantiate_paged_attention_inner(type, 96, block_size, num_threads, \
1134
+ num_simd_lanes, partition_size); \
1135
+ instantiate_paged_attention_inner(type, 112, block_size, num_threads, \
1136
+ num_simd_lanes, partition_size); \
1137
+ instantiate_paged_attention_inner(type, 128, block_size, num_threads, \
1138
+ num_simd_lanes, partition_size); \
1139
+ instantiate_paged_attention_inner(type, 192, block_size, num_threads, \
1140
+ num_simd_lanes, partition_size); \
1141
+ instantiate_paged_attention_inner(type, 256, block_size, num_threads, \
1142
+ num_simd_lanes, partition_size);
1143
+
1144
+ #define instantiate_paged_attention_v2_reduce_heads( \
1145
+ type, num_threads, num_simd_lanes, partition_size) \
1146
+ instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \
1147
+ num_simd_lanes, partition_size); \
1148
+ instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \
1149
+ num_simd_lanes, partition_size); \
1150
+ instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \
1151
+ num_simd_lanes, partition_size); \
1152
+ instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \
1153
+ num_simd_lanes, partition_size); \
1154
+ instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \
1155
+ num_simd_lanes, partition_size); \
1156
+ instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \
1157
+ num_simd_lanes, partition_size); \
1158
+ instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \
1159
+ num_simd_lanes, partition_size);
1160
+
1161
+ #define instantiate_paged_attention_block_size(type, num_threads, \
1162
+ num_simd_lanes, partition_size) \
1163
+ instantiate_paged_attention_heads(type, 8, num_threads, num_simd_lanes, \
1164
+ partition_size); \
1165
+ instantiate_paged_attention_heads(type, 16, num_threads, num_simd_lanes, \
1166
+ partition_size); \
1167
+ instantiate_paged_attention_heads(type, 32, num_threads, num_simd_lanes, \
1168
+ partition_size);
1169
+
1170
+ // TODO: tune num_threads = 256
1171
+ // NOTE: partition_size = 0
1172
+ #define instantiate_paged_attention_v1(type, num_simd_lanes) \
1173
+ instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
1174
+
1175
+ // TODO: tune num_threads = 256
1176
+ // NOTE: partition_size = 512
1177
+ #define instantiate_paged_attention_v2(type, num_simd_lanes) \
1178
+ instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
1179
+ instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
1180
+
1181
+ instantiate_paged_attention_v1(float, 32);
1182
+ instantiate_paged_attention_v1(bfloat16_t, 32);
1183
+ instantiate_paged_attention_v1(half, 32);
1184
+
1185
+ instantiate_paged_attention_v2(float, 32);
1186
+ instantiate_paged_attention_v2(bfloat16_t, 32);
1187
+ instantiate_paged_attention_v2(half, 32);
paged-attention-metal/cache/copy_blocks.metal ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "utils.metal"
2
+ #include <metal_stdlib>
3
+
4
+ using namespace metal;
5
+
6
+ template <typename T>
7
+ [[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]],
8
+ device T *value_cache [[buffer(1)]],
9
+ const device int64_t *block_mapping [[buffer(2)]],
10
+ device const int &numel_per_block,
11
+ uint gid [[thread_position_in_grid]],
12
+ uint tid [[thread_position_in_threadgroup]],
13
+ uint threads_per_threadgroup
14
+ [[threads_per_threadgroup]]) {
15
+ const int pair_idx = gid;
16
+
17
+ int64_t src_block_number = block_mapping[2 * pair_idx];
18
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
19
+
20
+ const int64_t src_block_offset = src_block_number * numel_per_block;
21
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
22
+
23
+ // Copy key cache blocks
24
+ for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
25
+ int64_t src_offset = src_block_offset + i;
26
+ int64_t dst_offset = dst_block_offset + i;
27
+ key_cache[dst_offset] = key_cache[src_offset];
28
+ }
29
+
30
+ // Copy value cache blocks
31
+ for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
32
+ int64_t src_offset = src_block_offset + i;
33
+ int64_t dst_offset = dst_block_offset + i;
34
+ value_cache[dst_offset] = value_cache[src_offset];
35
+ }
36
+ }
37
+
38
+ #define instantiate_copy_blocks(type) \
39
+ template [[host_name("copy_blocks_" #type)]] [[kernel]] void \
40
+ copy_blocks<type>(device type * key_cache_ptrs [[buffer(0)]], \
41
+ device type * value_cache_ptrs [[buffer(1)]], \
42
+ const device int64_t *block_mapping [[buffer(2)]], \
43
+ device const int &numel_per_block, \
44
+ uint gid [[thread_position_in_grid]], \
45
+ uint tid [[thread_position_in_threadgroup]], \
46
+ uint threads_per_threadgroup [[threads_per_threadgroup]]);
47
+
48
+ instantiate_copy_blocks(float);
49
+ instantiate_copy_blocks(bfloat16_t);
50
+ instantiate_copy_blocks(half);
paged-attention-metal/cache/reshape_and_cache.metal ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "utils.metal"
2
+ #include <metal_stdlib>
3
+
4
+ using namespace metal;
5
+
6
+ template <typename T>
7
+ [[kernel]] void reshape_and_cache(
8
+ const device T *__restrict__ key
9
+ [[buffer(0)]], // [num_tokens, num_heads, head_size]
10
+ const device T *__restrict__ value
11
+ [[buffer(1)]], // [num_tokens, num_heads, head_size]
12
+ device T *__restrict__ key_cache
13
+ [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x]
14
+ device T *__restrict__ value_cache
15
+ [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size]
16
+ const device int64_t *__restrict__ slot_mapping
17
+ [[buffer(4)]], // [num_tokens]
18
+ device const int &key_stride, device const int &value_stride,
19
+ device const int &num_heads, device const int &head_size,
20
+ device const int &block_size, device const int &x,
21
+ uint gid [[threadgroup_position_in_grid]],
22
+ uint tid [[thread_position_in_threadgroup]],
23
+ uint threads_per_threadgroup [[threads_per_threadgroup]]) {
24
+ const int64_t token_idx = gid;
25
+ const int64_t slot_idx = slot_mapping[token_idx];
26
+ if (slot_idx < 0) {
27
+ // Padding token that should be ignored.
28
+ return;
29
+ }
30
+
31
+ const int64_t block_idx = slot_idx / block_size;
32
+ const int64_t block_offset = slot_idx % block_size;
33
+
34
+ const int n = num_heads * head_size;
35
+ for (int i = tid; i < n; i += threads_per_threadgroup) {
36
+ const int64_t src_key_idx = token_idx * key_stride + i;
37
+ const int64_t src_value_idx = token_idx * value_stride + i;
38
+
39
+ const int head_idx = i / head_size;
40
+ const int head_offset = i % head_size;
41
+ const int x_idx = head_offset / x;
42
+ const int x_offset = head_offset % x;
43
+
44
+ const int64_t tgt_key_idx =
45
+ block_idx * num_heads * (head_size / x) * block_size * x +
46
+ head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
47
+ block_offset * x + x_offset;
48
+ const int64_t tgt_value_idx =
49
+ block_idx * num_heads * head_size * block_size +
50
+ head_idx * head_size * block_size + head_offset * block_size +
51
+ block_offset;
52
+ key_cache[tgt_key_idx] = key[src_key_idx];
53
+ value_cache[tgt_value_idx] = value[src_value_idx];
54
+ }
55
+ }
56
+
57
+ #define instantiate_reshape_and_cache(type) \
58
+ template [[host_name("reshape_and_cache_" #type)]] [[kernel]] void \
59
+ reshape_and_cache<type>( \
60
+ const device type *__restrict__ key [[buffer(0)]], \
61
+ const device type *__restrict__ value [[buffer(1)]], \
62
+ device type *__restrict__ key_cache [[buffer(2)]], \
63
+ device type *__restrict__ value_cache [[buffer(3)]], \
64
+ const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
65
+ device const int &key_stride, device const int &value_stride, \
66
+ device const int &num_heads, device const int &head_size, \
67
+ device const int &block_size, device const int &x, \
68
+ uint gid [[threadgroup_position_in_grid]], \
69
+ uint tid [[thread_position_in_threadgroup]], \
70
+ uint threads_per_threadgroup [[threads_per_threadgroup]]);
71
+
72
+ instantiate_reshape_and_cache(float);
73
+ instantiate_reshape_and_cache(bfloat16_t);
74
+ instantiate_reshape_and_cache(half);
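
The index arithmetic above packs keys as `[num_blocks, num_heads, head_size/x, block_size, x]` and values as `[num_blocks, num_heads, head_size, block_size]`. A rough PyTorch sketch of the same scatter, useful when reading the kernel (illustrative only; tensor names are placeholders):

```python
import torch


def reshape_and_cache_reference(key, value, key_cache, value_cache, slot_mapping):
    # key, value:  [num_tokens, num_heads, head_size]
    # key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
    # value_cache: [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    x = key_cache.shape[4]
    num_heads = key.shape[1]
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:
            # Padding token that should be ignored (same check as the kernel).
            continue
        block_idx, block_offset = divmod(slot_idx, block_size)
        # Keys are stored split along head_size into chunks of x elements.
        key_cache[block_idx, :, :, block_offset, :] = key[token_idx].reshape(
            num_heads, -1, x
        )
        # Values keep head_size contiguous, with block_offset as the last axis.
        value_cache[block_idx, :, :, block_offset] = value[token_idx]
```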
paged-attention-metal/paged_attention.mm ADDED
@@ -0,0 +1,117 @@
1
+ #include <torch/torch.h>
2
+
3
+ #import <Foundation/Foundation.h>
4
+ #import <Metal/Metal.h>
5
+ #include <string>
6
+
7
+ char const *CUSTOM_KERNEL = R"(
8
+ #include <metal_stdlib>
9
+ using namespace metal;
10
+ kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
11
+ device float *outC [[buffer(1)]],
12
+ uint index [[thread_position_in_grid]]) {
13
+ // Explicitly write to output
14
+ outC[index] = max(0.0f, inA[index]);
15
+ }
16
+ kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
17
+ device half *outC [[buffer(1)]],
18
+ uint index [[thread_position_in_grid]]) {
19
+ // Explicitly write to output
20
+ outC[index] = max(static_cast<half>(0.0), inA[index]);
21
+ }
22
+ )";
23
+
24
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
25
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
26
+ }
27
+
28
+ torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
29
+ torch::Tensor &output) {
30
+ @autoreleasepool {
31
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
32
+ NSError *error = nil;
33
+
34
+ int numThreads = input.numel();
35
+
36
+ id<MTLLibrary> customKernelLibrary = [device
37
+ newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL]
38
+ options:nil
39
+ error:&error];
40
+ TORCH_CHECK(customKernelLibrary,
41
+ "Failed to create custom kernel library, error: ",
42
+ error.localizedDescription.UTF8String);
43
+
44
+ std::string kernel_name =
45
+ std::string("relu_forward_kernel_") +
46
+ (input.scalar_type() == torch::kFloat ? "float" : "half");
47
+ id<MTLFunction> customReluFunction = [customKernelLibrary
48
+ newFunctionWithName:[NSString
49
+ stringWithUTF8String:kernel_name.c_str()]];
50
+ TORCH_CHECK(customReluFunction,
51
+ "Failed to create function state object for ",
52
+ kernel_name.c_str());
53
+
54
+ id<MTLComputePipelineState> reluPSO =
55
+ [device newComputePipelineStateWithFunction:customReluFunction
56
+ error:&error];
57
+ TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);
58
+
59
+ id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
60
+ TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
61
+
62
+ dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
63
+
64
+ dispatch_sync(serialQueue, ^() {
65
+ id<MTLComputeCommandEncoder> computeEncoder =
66
+ [commandBuffer computeCommandEncoder];
67
+ TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
68
+
69
+ [computeEncoder setComputePipelineState:reluPSO];
70
+ [computeEncoder setBuffer:getMTLBufferStorage(input)
71
+ offset:input.storage_offset() * input.element_size()
72
+ atIndex:0];
73
+ [computeEncoder setBuffer:getMTLBufferStorage(output)
74
+ offset:output.storage_offset() * output.element_size()
75
+ atIndex:1];
76
+
77
+ MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
78
+
79
+ NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup;
80
+ if (threadGroupSize > numThreads) {
81
+ threadGroupSize = numThreads;
82
+ }
83
+ MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);
84
+
85
+ [computeEncoder dispatchThreads:gridSize
86
+ threadsPerThreadgroup:threadgroupSize];
87
+
88
+ [computeEncoder endEncoding];
89
+
90
+ torch::mps::commit();
91
+ });
92
+ }
93
+
94
+ return output;
95
+ }
96
+
97
+ void relu(torch::Tensor &out, const torch::Tensor &input) {
98
+ TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
99
+ TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
100
+ TORCH_CHECK(input.scalar_type() == torch::kFloat ||
101
+ input.scalar_type() == torch::kHalf,
102
+ "Unsupported data type: ", input.scalar_type());
103
+
104
+ TORCH_CHECK(input.sizes() == out.sizes(),
105
+ "Tensors must have the same shape. Got input shape: ",
106
+ input.sizes(), " and output shape: ", out.sizes());
107
+
108
+ TORCH_CHECK(input.scalar_type() == out.scalar_type(),
109
+ "Tensors must have the same data type. Got input dtype: ",
110
+ input.scalar_type(), " and output dtype: ", out.scalar_type());
111
+
112
+ TORCH_CHECK(input.device() == out.device(),
113
+ "Tensors must be on the same device. Got input device: ",
114
+ input.device(), " and output device: ", out.device());
115
+
116
+ dispatchReluKernel(input, out);
117
+ }
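
Assuming the Torch extension declared in `build.toml` is importable as `paged_attention` and exposes the `relu(out, input)` entry point above (the Python binding name is an assumption here), a minimal MPS usage sketch might look like:

```python
import torch

import paged_attention as ops  # assumed import name for the built extension

x = torch.randn(1024, device="mps", dtype=torch.float32)
out = torch.empty_like(x)
ops.relu(out, x)  # fills `out` with max(0, x) via the Metal kernel above
torch.testing.assert_close(out, torch.relu(x))
```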
paged-attention-metal/utils.metal ADDED
File without changes
tests/kernels/__init__.py ADDED
File without changes
tests/kernels/allclose_default.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+
3
+ # Reference default values of atol and rtol are from
4
+ # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
5
+ default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
6
+ default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
7
+
8
+
9
+ def get_default_atol(output) -> float:
10
+ return default_atol[output.dtype]
11
+
12
+
13
+ def get_default_rtol(output) -> float:
14
+ return default_rtol[output.dtype]
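
A minimal usage sketch of these helpers (import path assumed for illustration): pick dtype-dependent tolerances when comparing a kernel output against a reference, as `test_attention.py` does on ROCm.

```python
import torch

from tests.kernels.allclose_default import get_default_atol, get_default_rtol

output = torch.randn(8, 16, dtype=torch.float16)
ref_output = output.clone()
torch.testing.assert_close(
    output,
    ref_output,
    atol=get_default_atol(output),
    rtol=get_default_rtol(output),
)
```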
tests/kernels/conftest.py ADDED
@@ -0,0 +1,158 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import paged_attention as ops
4
+ import pytest
5
+ import torch
6
+
7
+
8
+ @pytest.fixture()
9
+ def kv_cache_factory():
10
+ return create_kv_caches_with_random
11
+
12
+
13
+ @pytest.fixture()
14
+ def kv_cache_factory_flashinfer():
15
+ return create_kv_caches_with_random_flash
16
+
17
+
18
+ STR_DTYPE_TO_TORCH_DTYPE = {
19
+ "half": torch.half,
20
+ "bfloat16": torch.bfloat16,
21
+ "float": torch.float,
22
+ "fp8": torch.uint8,
23
+ "fp8_e4m3": torch.uint8,
24
+ "fp8_e5m2": torch.uint8,
25
+ }
26
+
27
+
28
+ def create_kv_caches_with_random(
29
+ num_blocks: int,
30
+ block_size: int,
31
+ num_layers: int,
32
+ num_heads: int,
33
+ head_size: int,
34
+ cache_dtype: Optional[Union[str, torch.dtype]],
35
+ model_dtype: Optional[Union[str, torch.dtype]] = None,
36
+ seed: int = 0,
37
+ device: Optional[str] = "cuda",
38
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
39
+
40
+ if cache_dtype == "fp8" and head_size % 16:
41
+ raise ValueError(
42
+ f"Does not support key cache of type fp8 with head_size {head_size}"
43
+ )
44
+ from paged_attention.platforms import current_platform
45
+
46
+ current_platform.seed_everything(seed)
47
+
48
+ torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
49
+
50
+ scale = head_size**-0.5
51
+ x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
52
+ key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
53
+ key_caches: List[torch.Tensor] = []
54
+ for _ in range(num_layers):
55
+ key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
56
+ if cache_dtype in ["auto", "half", "bfloat16", "float"]:
57
+ key_cache.uniform_(-scale, scale)
58
+ elif cache_dtype == "fp8":
59
+ _generate_random_fp8(key_cache, -scale, scale)
60
+ else:
61
+ raise ValueError(f"Does not support key cache of type {cache_dtype}")
62
+ key_caches.append(key_cache)
63
+
64
+ value_cache_shape = (num_blocks, num_heads, head_size, block_size)
65
+ value_caches: List[torch.Tensor] = []
66
+ for _ in range(num_layers):
67
+ value_cache = torch.empty(
68
+ size=value_cache_shape, dtype=torch_dtype, device=device
69
+ )
70
+ if cache_dtype in ["auto", "half", "bfloat16", "float"]:
71
+ value_cache.uniform_(-scale, scale)
72
+ elif cache_dtype == "fp8":
73
+ _generate_random_fp8(value_cache, -scale, scale)
74
+ else:
75
+ raise ValueError(f"Does not support value cache of type {cache_dtype}")
76
+ value_caches.append(value_cache)
77
+ return key_caches, value_caches
78
+
79
+
80
+ def create_kv_caches_with_random_flash(
81
+ num_blocks: int,
82
+ block_size: int,
83
+ num_layers: int,
84
+ num_heads: int,
85
+ head_size: int,
86
+ cache_dtype: Optional[Union[str, torch.dtype]],
87
+ model_dtype: Optional[Union[str, torch.dtype]] = None,
88
+ seed: int = 0,
89
+ device: Optional[str] = "cuda",
90
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
91
+ from paged_attention.platforms import current_platform
92
+
93
+ current_platform.seed_everything(seed)
94
+
95
+ torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
96
+ key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
97
+ scale = head_size**-0.5
98
+
99
+ key_caches: List[torch.Tensor] = []
100
+ value_caches: List[torch.Tensor] = []
101
+
102
+ for _ in range(num_layers):
103
+ key_value_cache = torch.empty(
104
+ size=key_value_cache_shape, dtype=torch_dtype, device=device
105
+ )
106
+ if cache_dtype in ["auto", "half", "bfloat16", "float"]:
107
+ key_value_cache.uniform_(-scale, scale)
108
+ elif cache_dtype == "fp8":
109
+ _generate_random_fp8(key_value_cache, -scale, scale)
110
+ else:
111
+ raise ValueError(f"Does not support key cache of type {cache_dtype}")
112
+ key_caches.append(key_value_cache[:, 0])
113
+ value_caches.append(key_value_cache[:, 1])
114
+ return key_caches, value_caches
115
+
116
+
117
+ def get_kv_cache_torch_dtype(
118
+ cache_dtype: Optional[Union[str, torch.dtype]],
119
+ model_dtype: Optional[Union[str, torch.dtype]] = None,
120
+ ) -> torch.dtype:
121
+ if isinstance(cache_dtype, str):
122
+ if cache_dtype == "auto":
123
+ if isinstance(model_dtype, str):
124
+ torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
125
+ elif isinstance(model_dtype, torch.dtype):
126
+ torch_dtype = model_dtype
127
+ else:
128
+ raise ValueError(f"Invalid model dtype: {model_dtype}")
129
+ elif cache_dtype in ["half", "bfloat16", "float"]:
130
+ torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
131
+ elif cache_dtype == "fp8":
132
+ torch_dtype = torch.uint8
133
+ else:
134
+ raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
135
+ elif isinstance(cache_dtype, torch.dtype):
136
+ torch_dtype = cache_dtype
137
+ else:
138
+ raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
139
+ return torch_dtype
140
+
141
+
142
+ def _generate_random_fp8(
143
+ tensor: torch.Tensor,
144
+ low: float,
145
+ high: float,
146
+ ) -> None:
147
+ # NOTE(zhaoyang): Due to the NaN and Inf representations in the fp8 data types,
148
+ # Inf or NaN values may occur if we directly use torch.randint
149
+ # to generate random fp8 data.
150
+ # For example, s.11111.00 in fp8e5m2 format represents Inf.
151
+ # | E4M3 | E5M2
152
+ # -----|-------------|-------------------
153
+ # Inf | N/A | s.11111.00
154
+ # NaN | s.1111.111 | s.11111.{01,10,11}
155
+ tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
156
+ tensor_tmp.uniform_(low, high)
157
+ ops.convert_fp8(tensor, tensor_tmp)
158
+ del tensor_tmp
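
To make the cache shapes concrete, here is an illustrative sketch of what `create_kv_caches_with_random` returns for fp16 (import path and CUDA availability assumed):

```python
import torch

from tests.kernels.conftest import create_kv_caches_with_random  # path assumed

key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=128,
    block_size=16,
    num_layers=1,
    num_heads=8,
    head_size=64,
    cache_dtype="auto",
    model_dtype=torch.float16,
    device="cuda",
)
# For fp16, x = 16 bytes / 2 bytes per element = 8, so:
assert key_caches[0].shape == (128, 8, 64 // 8, 16, 8)
assert value_caches[0].shape == (128, 8, 64, 16)
```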
tests/kernels/test_attention.py ADDED
@@ -0,0 +1,418 @@
1
+ import random
2
+ from typing import List, Optional, Tuple
3
+
4
+ import paged_attention as ops
5
+ import pytest
6
+ import torch
7
+ from paged_attention.platforms import current_platform
8
+
9
+ from .allclose_default import get_default_atol, get_default_rtol
10
+ from .utils import get_max_shared_memory_bytes, opcheck
11
+
12
+ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
13
+ # This will change depending on the compute capability.
14
+ # Subtract 512 as a buffer.
15
+ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
16
+ # There may not be enough gpu memory due to large NUM_BLOCKS.
17
+ # Reduce NUM_BLOCKS when it happens.
18
+ NUM_BLOCKS = 4321 # Arbitrary values for testing
19
+ PARTITION_SIZE = 512
20
+ # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
21
+ DTYPES = (
22
+ [torch.half, torch.bfloat16, torch.float]
23
+ if not current_platform.is_rocm()
24
+ else [torch.half, torch.bfloat16]
25
+ )
26
+ NUM_GEN_SEQS = [7] # Arbitrary values for testing
27
+ NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
28
+ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
29
+
30
+ # This should be kept in sync with get_supported_head_sizes() in
31
+ # vllm.attention.ops.paged_attn.PagedAttention
32
+ HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
33
+
34
+ BLOCK_SIZES = [16, 32]
35
+ USE_ALIBI = [False, True]
36
+ KV_CACHE_DTYPE = ["auto", "fp8"]
37
+ SEEDS = [0]
38
+ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
39
+
40
+
41
+ def ref_masked_attention(
42
+ query: torch.Tensor,
43
+ key: torch.Tensor,
44
+ value: torch.Tensor,
45
+ scale: float,
46
+ attn_mask: Optional[torch.Tensor] = None,
47
+ ) -> torch.Tensor:
48
+ attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
49
+ if attn_mask is not None:
50
+ attn_weights = attn_weights + attn_mask.float()
51
+ attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
52
+ out = torch.einsum("hqk,khd->qhd", attn_weights, value)
53
+ return out
54
+
55
+
56
+ def ref_single_query_cached_kv_attention(
57
+ output: torch.Tensor,
58
+ query: torch.Tensor,
59
+ num_queries_per_kv: int,
60
+ key_cache: torch.Tensor,
61
+ value_cache: torch.Tensor,
62
+ block_tables: torch.Tensor,
63
+ seq_lens: torch.Tensor,
64
+ scale: float,
65
+ alibi_slopes: Optional[torch.Tensor],
66
+ ) -> None:
67
+ num_query_heads = query.shape[1]
68
+ num_kv_heads = value_cache.shape[1]
69
+ head_size = value_cache.shape[2]
70
+ block_size = value_cache.shape[3]
71
+ num_seqs = query.shape[0]
72
+
73
+ block_tables_lst = block_tables.cpu().tolist()
74
+ seq_lens_lst = seq_lens.cpu().tolist()
75
+ for i in range(num_seqs):
76
+ q = query[i].unsqueeze(0)
77
+ block_table = block_tables_lst[i]
78
+ seq_len = int(seq_lens_lst[i])
79
+
80
+ keys_lst: List[torch.Tensor] = []
81
+ values_lst: List[torch.Tensor] = []
82
+ for j in range(seq_len):
83
+ block_number = int(block_table[j // block_size])
84
+ block_offset = j % block_size
85
+
86
+ k = key_cache[block_number, :, :, block_offset, :]
87
+ k = k.reshape(num_kv_heads, head_size)
88
+ keys_lst.append(k)
89
+
90
+ v = value_cache[block_number, :, :, block_offset]
91
+ values_lst.append(v)
92
+ keys = torch.stack(keys_lst, dim=0)
93
+ values = torch.stack(values_lst, dim=0)
94
+ if num_queries_per_kv > 1:
95
+ # Handle MQA and GQA
96
+ keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
97
+ values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
98
+
99
+ alibi_bias = None
100
+ if alibi_slopes is not None:
101
+ # Create the ALiBi bias used in the paged attention kernel.
102
+ position_ids = torch.arange(seq_len).int()
103
+ alibi_bias = (position_ids - seq_len + 1).float()
104
+ alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
105
+
106
+ out = ref_masked_attention(q, keys, values, scale, alibi_bias)
107
+ out = out.view(num_query_heads, head_size)
108
+ output[i].copy_(out, non_blocking=True)
109
+
110
+
111
+ @pytest.mark.parametrize(
112
+ "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
113
+ )
114
+ @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
115
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
116
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
117
+ @pytest.mark.parametrize("use_alibi", USE_ALIBI)
118
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
119
+ @pytest.mark.parametrize("dtype", DTYPES)
120
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
121
+ @pytest.mark.parametrize("seed", SEEDS)
122
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
123
+ def test_paged_attention(
124
+ kv_cache_factory,
125
+ version: str,
126
+ num_seqs: int,
127
+ num_heads: Tuple[int, int],
128
+ head_size: int,
129
+ use_alibi: bool,
130
+ block_size: int,
131
+ dtype: torch.dtype,
132
+ kv_cache_dtype: str,
133
+ seed: int,
134
+ device: str,
135
+ ) -> None:
136
+ if (kv_cache_dtype == "fp8" and head_size % 16) or (
137
+ version == "rocm" and head_size not in (64, 128)
138
+ ):
139
+ pytest.skip()
140
+
141
+ current_platform.seed_everything(seed)
142
+ torch.set_default_device(device)
143
+ scale = float(1.0 / (head_size**0.5))
144
+ num_query_heads, num_kv_heads = num_heads
145
+ query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
146
+ query.uniform_(-scale, scale)
147
+
148
+ assert num_query_heads % num_kv_heads == 0
149
+ num_queries_per_kv = num_query_heads // num_kv_heads
150
+ alibi_slopes = None
151
+ if use_alibi:
152
+ alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
153
+
154
+ seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
155
+ seq_lens[-1] = MAX_SEQ_LEN
156
+ max_seq_len = max(seq_lens)
157
+ seq_lens = torch.tensor(seq_lens, dtype=torch.int)
158
+
159
+ # Create the block tables.
160
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
161
+ block_tables_lst: List[List[int]] = []
162
+ for _ in range(num_seqs):
163
+ block_table = [
164
+ random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
165
+ ]
166
+ block_tables_lst.append(block_table)
167
+
168
+ block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
169
+
170
+ # Create the KV caches.
171
+ key_caches, value_caches = kv_cache_factory(
172
+ NUM_BLOCKS,
173
+ block_size,
174
+ 1,
175
+ num_kv_heads,
176
+ head_size,
177
+ kv_cache_dtype,
178
+ dtype,
179
+ seed,
180
+ device,
181
+ )
182
+ key_cache, value_cache = key_caches[0], value_caches[0]
183
+
184
+ # Using default kv_scale
185
+ k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
186
+
187
+ # Call the paged attention kernel.
188
+ output = torch.empty_like(query)
189
+ if version == "v1":
190
+ ops.paged_attention_v1(
191
+ output,
192
+ query,
193
+ key_cache,
194
+ value_cache,
195
+ num_kv_heads,
196
+ scale,
197
+ block_tables,
198
+ seq_lens,
199
+ block_size,
200
+ max_seq_len,
201
+ alibi_slopes,
202
+ kv_cache_dtype,
203
+ k_scale,
204
+ v_scale,
205
+ )
206
+
207
+ opcheck(
208
+ ops.ops.paged_attention_v1,
209
+ (
210
+ output,
211
+ query,
212
+ key_cache,
213
+ value_cache,
214
+ num_kv_heads,
215
+ scale,
216
+ block_tables,
217
+ seq_lens,
218
+ block_size,
219
+ max_seq_len,
220
+ alibi_slopes,
221
+ kv_cache_dtype,
222
+ k_scale,
223
+ v_scale,
224
+ 0,
225
+ 0,
226
+ 0,
227
+ 64,
228
+ 0,
229
+ ),
230
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
231
+ )
232
+
233
+ elif version in ("v2", "rocm"):
234
+ num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
235
+ assert PARTITION_SIZE % block_size == 0
236
+ num_seqs, num_heads, head_size = output.shape
237
+ tmp_output = torch.empty(
238
+ size=(num_seqs, num_heads, num_partitions, head_size),
239
+ dtype=output.dtype,
240
+ )
241
+ exp_sums = torch.empty(
242
+ size=(num_seqs, num_heads, num_partitions),
243
+ dtype=torch.float32,
244
+ )
245
+ max_logits = torch.empty_like(exp_sums)
246
+ if version == "v2":
247
+ ops.paged_attention_v2(
248
+ output,
249
+ exp_sums,
250
+ max_logits,
251
+ tmp_output,
252
+ query,
253
+ key_cache,
254
+ value_cache,
255
+ num_kv_heads,
256
+ scale,
257
+ block_tables,
258
+ seq_lens,
259
+ block_size,
260
+ max_seq_len,
261
+ alibi_slopes,
262
+ kv_cache_dtype,
263
+ k_scale,
264
+ v_scale,
265
+ )
266
+
267
+ opcheck(
268
+ ops.ops.paged_attention_v2,
269
+ (
270
+ output,
271
+ exp_sums,
272
+ max_logits,
273
+ tmp_output,
274
+ query,
275
+ key_cache,
276
+ value_cache,
277
+ num_kv_heads,
278
+ scale,
279
+ block_tables,
280
+ seq_lens,
281
+ block_size,
282
+ max_seq_len,
283
+ alibi_slopes,
284
+ kv_cache_dtype,
285
+ k_scale,
286
+ v_scale,
287
+ 0,
288
+ 0,
289
+ 0,
290
+ 64,
291
+ 0,
292
+ ),
293
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
294
+ )
295
+
296
+ else:
297
+ ops.paged_attention_rocm(
298
+ output,
299
+ exp_sums,
300
+ max_logits,
301
+ tmp_output,
302
+ query,
303
+ key_cache,
304
+ value_cache,
305
+ num_kv_heads,
306
+ scale,
307
+ block_tables,
308
+ seq_lens,
309
+ block_size,
310
+ max_seq_len,
311
+ alibi_slopes,
312
+ kv_cache_dtype,
313
+ k_scale,
314
+ v_scale,
315
+ )
316
+
317
+ opcheck(
318
+ torch.ops._rocm_C.paged_attention,
319
+ (
320
+ output,
321
+ exp_sums,
322
+ max_logits,
323
+ tmp_output,
324
+ query,
325
+ key_cache,
326
+ value_cache,
327
+ num_kv_heads,
328
+ scale,
329
+ block_tables,
330
+ seq_lens,
331
+ block_size,
332
+ max_seq_len,
333
+ alibi_slopes,
334
+ kv_cache_dtype,
335
+ k_scale,
336
+ v_scale,
337
+ ),
338
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
339
+ )
340
+
341
+ else:
342
+ raise AssertionError(f"Unknown version: {version}")
343
+
344
+ # Run the reference implementation.
345
+ if kv_cache_dtype == "fp8":
346
+ # Convert cache data back to dtype.
347
+ x = 16 // torch.tensor([], dtype=dtype).element_size()
348
+ key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
349
+ dequantized_key_cache = torch.empty(
350
+ size=key_cache_shape, dtype=dtype, device=device
351
+ )
352
+ ops.convert_fp8(dequantized_key_cache, key_cache)
353
+ key_cache = dequantized_key_cache
354
+
355
+ value_cache_shape = value_cache.shape
356
+ dequantized_value_cache = torch.empty(
357
+ size=value_cache_shape, dtype=dtype, device=device
358
+ )
359
+ ops.convert_fp8(dequantized_value_cache, value_cache)
360
+ value_cache = dequantized_value_cache
361
+
362
+ ref_output = torch.empty_like(query)
363
+ ref_single_query_cached_kv_attention(
364
+ ref_output,
365
+ query,
366
+ num_queries_per_kv,
367
+ key_cache,
368
+ value_cache,
369
+ block_tables,
370
+ seq_lens,
371
+ scale,
372
+ alibi_slopes,
373
+ )
374
+
375
+ # NOTE(woosuk): Due to the kernel-level differences in the two
376
+ # implementations, there is a small numerical difference in the two
377
+ # outputs. Thus, we use a relaxed tolerance for the test.
378
+ atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
379
+ rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
380
+
381
+ # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
382
+ # so we use a relaxed tolerance for the test.
383
+ atol, rtol = 1e-3, 1e-5
384
+ if kv_cache_dtype == "fp8":
385
+ atol, rtol = 1e-2, 1e-5
386
+ torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
387
+
388
+
389
+ def ref_multi_query_kv_attention(
390
+ cu_seq_lens: List[int],
391
+ query: torch.Tensor,
392
+ key: torch.Tensor,
393
+ value: torch.Tensor,
394
+ scale: float,
395
+ dtype: torch.dtype,
396
+ ) -> torch.Tensor:
397
+ num_seqs = len(cu_seq_lens) - 1
398
+ ref_outputs: List[torch.Tensor] = []
399
+ for i in range(num_seqs):
400
+ start_idx = cu_seq_lens[i]
401
+ end_idx = cu_seq_lens[i + 1]
402
+ seq_len = end_idx - start_idx
403
+
404
+ # Create attention mask.
405
+ attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
406
+ attn_mask = attn_mask * torch.finfo(dtype).min
407
+ attn_mask = attn_mask.to(dtype=dtype)
408
+
409
+ ref_output = ref_masked_attention(
410
+ query[start_idx:end_idx],
411
+ key[start_idx:end_idx],
412
+ value[start_idx:end_idx],
413
+ scale,
414
+ attn_mask=attn_mask,
415
+ )
416
+ ref_outputs.append(ref_output)
417
+
418
+ return torch.cat(ref_outputs, dim=0)
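
The reference implementation above resolves each token position through the block table before gathering keys and values; a tiny sketch of that lookup (illustrative only):

```python
def lookup(block_table, j, block_size):
    # Which physical block holds token j, and where inside that block it sits.
    block_number = block_table[j // block_size]
    block_offset = j % block_size
    return block_number, block_offset


# With block_size=16, token 37 of a sequence whose block table is [5, 9, 2]
# lives in physical block 2, at offset 5 inside that block.
assert lookup([5, 9, 2], 37, 16) == (2, 5)
```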
tests/kernels/test_cache.py ADDED
@@ -0,0 +1,486 @@
1
+ import random
2
+ from typing import List, Tuple
3
+
4
+ import paged_attention as ops
5
+ import pytest
6
+ import torch
7
+ from paged_attention.platforms import current_platform
8
+
9
+ from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
+
11
+ COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
12
+ DTYPES = [torch.half, torch.bfloat16, torch.float]
13
+ NUM_TOKENS = [42] # Arbitrary values for testing
14
+ NUM_LAYERS = [1] # Arbitrary values for testing
15
+ NUM_HEADS = [8] # Arbitrary values for testing
16
+ HEAD_SIZES = [64, 80, 120, 256]
17
+ BLOCK_SIZES = [8, 16, 32]
18
+
19
+ # Arbitrary values for testing
20
+ # don't make it too large. e.g. [1024, 36000] will OOM
21
+ NUM_BLOCKS = [1024, 10000]
22
+
23
+ NUM_MAPPINGS = [256] # Arbitrary values for testing
24
+ SEEDS = [0]
25
+ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
26
+
27
+ # We assume fp8 is always enabled for testing.
28
+ KV_CACHE_DTYPE = ["auto", "fp8"]
29
+
30
+
31
+ @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
32
+ @pytest.mark.parametrize("num_layers", NUM_LAYERS)
33
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
34
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
35
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
36
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
37
+ @pytest.mark.parametrize("dtype", DTYPES)
38
+ @pytest.mark.parametrize("seed", SEEDS)
39
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
40
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
41
+ @torch.inference_mode()
42
+ def test_copy_blocks(
43
+ kv_cache_factory,
44
+ num_mappings: int,
45
+ num_layers: int,
46
+ num_heads: int,
47
+ head_size: int,
48
+ block_size: int,
49
+ num_blocks: int,
50
+ dtype: torch.dtype,
51
+ seed: int,
52
+ kv_cache_dtype: str,
53
+ device: str,
54
+ ) -> None:
55
+ if kv_cache_dtype == "fp8" and head_size % 16:
56
+ pytest.skip()
57
+ current_platform.seed_everything(seed)
58
+ torch.set_default_device(device)
59
+ # Generate random block mappings where each source block is mapped to two
60
+ # destination blocks.
61
+ assert 2 * num_mappings <= num_blocks
62
+ src_blocks = random.sample(range(num_blocks), num_mappings)
63
+ remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
64
+ dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
65
+ block_mapping: List[Tuple[int, int]] = []
66
+ for i in range(num_mappings):
67
+ src = src_blocks[i]
68
+ dst1 = dst_blocks[2 * i]
69
+ dst2 = dst_blocks[2 * i + 1]
70
+ block_mapping.append((src, dst1))
71
+ block_mapping.append((src, dst2))
72
+
73
+ # Create the KV caches.
74
+ key_caches, value_caches = kv_cache_factory(
75
+ num_blocks,
76
+ block_size,
77
+ num_layers,
78
+ num_heads,
79
+ head_size,
80
+ kv_cache_dtype,
81
+ dtype,
82
+ seed,
83
+ device,
84
+ )
85
+
86
+ # Clone the KV caches.
87
+ cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
88
+ cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
89
+
90
+ # Call the copy blocks kernel.
91
+ block_mapping_tensor = torch.tensor(
92
+ block_mapping, dtype=torch.int64, device=device
93
+ ).view(-1, 2)
94
+
95
+ opcheck(
96
+ ops.ops.copy_blocks,
97
+ (key_caches, value_caches, block_mapping_tensor),
98
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS,
99
+ cond=(head_size == HEAD_SIZES[0]),
100
+ )
101
+ ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
102
+
103
+ # Run the reference implementation.
104
+ for src, dst in block_mapping:
105
+ for cloned_key_cache in cloned_key_caches:
106
+ cloned_key_cache[dst].copy_(cloned_key_cache[src])
107
+ for cloned_value_cache in cloned_value_caches:
108
+ cloned_value_cache[dst].copy_(cloned_value_cache[src])
109
+
110
+ # Compare the results.
111
+ for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
112
+ torch.testing.assert_close(key_cache, cloned_key_cache)
113
+ for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
114
+ torch.testing.assert_close(value_cache, cloned_value_cache)
115
+
116
+
117
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
118
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
119
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
120
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
121
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
122
+ @pytest.mark.parametrize("dtype", DTYPES)
123
+ @pytest.mark.parametrize("seed", SEEDS)
124
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
125
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
126
+ @torch.inference_mode()
127
+ def test_reshape_and_cache(
128
+ kv_cache_factory,
129
+ num_tokens: int,
130
+ num_heads: int,
131
+ head_size: int,
132
+ block_size: int,
133
+ num_blocks: int,
134
+ dtype: torch.dtype,
135
+ seed: int,
136
+ device: str,
137
+ kv_cache_dtype: str,
138
+ ) -> None:
139
+ if kv_cache_dtype == "fp8" and head_size % 16:
140
+ pytest.skip()
141
+ current_platform.seed_everything(seed)
142
+ torch.set_default_device(device)
143
+ # Create a random slot mapping.
144
+ num_slots = block_size * num_blocks
145
+ slot_mapping_lst = random.sample(range(num_slots), num_tokens)
146
+ slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
147
+
148
+ qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
149
+ _, key, value = qkv.unbind(dim=1)
150
+
151
+ # Create the KV caches.
152
+ key_caches, value_caches = kv_cache_factory(
153
+ num_blocks,
154
+ block_size,
155
+ 1,
156
+ num_heads,
157
+ head_size,
158
+ kv_cache_dtype,
159
+ dtype,
160
+ seed,
161
+ device,
162
+ )
163
+ key_cache, value_cache = key_caches[0], value_caches[0]
164
+
165
+ # Clone the KV caches.
166
+ if kv_cache_dtype == "fp8":
167
+ cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
168
+ ops.convert_fp8(cloned_key_cache, key_cache)
169
+ cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
170
+ ops.convert_fp8(cloned_value_cache, value_cache)
171
+ else:
172
+ cloned_key_cache = key_cache.clone()
173
+ cloned_value_cache = value_cache.clone()
174
+
175
+ # Using default kv_scale
176
+ k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
177
+
178
+ # Call the reshape_and_cache kernel.
179
+ opcheck(
180
+ ops.ops.reshape_and_cache,
181
+ (
182
+ key,
183
+ value,
184
+ key_cache,
185
+ value_cache,
186
+ slot_mapping,
187
+ kv_cache_dtype,
188
+ k_scale,
189
+ v_scale,
190
+ ),
191
+ cond=(head_size == HEAD_SIZES[0]),
192
+ )
193
+ ops.reshape_and_cache(
194
+ key,
195
+ value,
196
+ key_cache,
197
+ value_cache,
198
+ slot_mapping,
199
+ kv_cache_dtype,
200
+ k_scale,
201
+ v_scale,
202
+ )
203
+
204
+ if kv_cache_dtype == "fp8":
205
+ result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
206
+ ops.convert_fp8(result_key_cache, key_cache)
207
+ result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
208
+ ops.convert_fp8(result_value_cache, value_cache)
209
+
210
+ # Run the reference implementation.
211
+ reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
212
+ block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
213
+ block_indicies_lst = block_indicies.cpu().tolist()
214
+ block_offsets = slot_mapping % block_size
215
+ block_offsets_lst = block_offsets.cpu().tolist()
216
+ for i in range(num_tokens):
217
+ block_idx = block_indicies_lst[i]
218
+ block_offset = block_offsets_lst[i]
219
+ cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
220
+ cloned_value_cache[block_idx, :, :, block_offset] = value[i]
221
+
222
+ if kv_cache_dtype == "fp8":
223
+ torch.testing.assert_close(
224
+ result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
225
+ )
226
+ torch.testing.assert_close(
227
+ result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
228
+ )
229
+ else:
230
+ torch.testing.assert_close(key_cache, cloned_key_cache)
231
+ torch.testing.assert_close(value_cache, cloned_value_cache)
232
+
233
+
234
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
235
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
236
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
237
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
238
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
239
+ @pytest.mark.parametrize("dtype", DTYPES)
240
+ @pytest.mark.parametrize("seed", SEEDS)
241
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
242
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
243
+ @torch.inference_mode()
244
+ def test_reshape_and_cache_flash(
245
+ kv_cache_factory_flashinfer,
246
+ num_tokens: int,
247
+ num_heads: int,
248
+ head_size: int,
249
+ block_size: int,
250
+ num_blocks: int,
251
+ dtype: torch.dtype,
252
+ seed: int,
253
+ device: str,
254
+ kv_cache_dtype: str,
255
+ ) -> None:
256
+ current_platform.seed_everything(seed)
257
+ torch.set_default_device(device)
258
+
259
+ # Create a random slot mapping.
260
+ num_slots = block_size * num_blocks
261
+ slot_mapping_lst = random.sample(range(num_slots), num_tokens)
262
+ slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
263
+
264
+ qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
265
+ _, key, value = qkv.unbind(dim=1)
266
+
267
+ # Create the KV caches.
268
+ key_caches, value_caches = kv_cache_factory_flashinfer(
269
+ num_blocks,
270
+ block_size,
271
+ 1,
272
+ num_heads,
273
+ head_size,
274
+ kv_cache_dtype,
275
+ dtype,
276
+ device=device,
277
+ )
278
+ key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
279
+ del key_caches
280
+ del value_caches
281
+
282
+ k_scale = (key.amax() / 256.0).to(torch.float32)
283
+ v_scale = (value.amax() / 256.0).to(torch.float32)
284
+
285
+ # Clone the KV caches.
286
+ if kv_cache_dtype == "fp8":
287
+ cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
288
+ ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
289
+ cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
290
+ ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype)
291
+ else:
292
+ cloned_key_cache = key_cache.clone()
293
+ cloned_value_cache = value_cache.clone()
294
+
295
+ # Call the reshape_and_cache kernel.
296
+ opcheck(
297
+ ops.ops.reshape_and_cache_flash,
298
+ (
299
+ key,
300
+ value,
301
+ key_cache,
302
+ value_cache,
303
+ slot_mapping,
304
+ kv_cache_dtype,
305
+ k_scale,
306
+ v_scale,
307
+ ),
308
+ cond=(head_size == HEAD_SIZES[0]),
309
+ )
310
+ ops.reshape_and_cache_flash(
311
+ key,
312
+ value,
313
+ key_cache,
314
+ value_cache,
315
+ slot_mapping,
316
+ kv_cache_dtype,
317
+ k_scale,
318
+ v_scale,
319
+ )
320
+
321
+ if kv_cache_dtype == "fp8":
322
+ result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
323
+ ops.convert_fp8(
324
+ result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
325
+ )
326
+ result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
327
+ ops.convert_fp8(
328
+ result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
329
+ )
330
+
331
+ # Run the reference implementation.
332
+ block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
333
+ block_indicies_lst = block_indicies.cpu().tolist()
334
+ block_offsets = slot_mapping % block_size
335
+ block_offsets_lst = block_offsets.cpu().tolist()
336
+ for i in range(num_tokens):
337
+ block_idx = block_indicies_lst[i]
338
+ block_offset = block_offsets_lst[i]
339
+ cloned_key_cache[block_idx, block_offset, :, :] = key[i]
340
+ cloned_value_cache[block_idx, block_offset, :, :] = value[i]
341
+
342
+ if kv_cache_dtype == "fp8":
343
+ torch.testing.assert_close(
344
+ result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
345
+ )
346
+ torch.testing.assert_close(
347
+ result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
348
+ )
349
+ else:
350
+ torch.testing.assert_close(key_cache, cloned_key_cache)
351
+ torch.testing.assert_close(value_cache, cloned_value_cache)
352
+
353
+
354
+ @pytest.mark.parametrize("direction", COPYING_DIRECTION)
355
+ @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
356
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
357
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
358
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
359
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
360
+ @pytest.mark.parametrize("dtype", DTYPES)
361
+ @pytest.mark.parametrize("seed", SEEDS)
362
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
363
+ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
364
+ @torch.inference_mode()
365
+ def test_swap_blocks(
366
+ kv_cache_factory,
367
+ direction: Tuple[str, str],
368
+ num_mappings: int,
369
+ num_heads: int,
370
+ head_size: int,
371
+ block_size: int,
372
+ num_blocks: int,
373
+ dtype: torch.dtype,
374
+ seed: int,
375
+ device: str,
376
+ kv_cache_dtype: str,
377
+ ) -> None:
378
+ if kv_cache_dtype == "fp8" and "cpu" in direction:
379
+ pytest.skip()
380
+ if kv_cache_dtype == "fp8" and head_size % 16:
381
+ pytest.skip()
382
+
383
+ current_platform.seed_everything(seed)
384
+
385
+ src_device = device if direction[0] == "cuda" else "cpu"
386
+ dst_device = device if direction[1] == "cuda" else "cpu"
387
+
388
+ src_blocks = random.sample(range(num_blocks), num_mappings)
389
+ # For the same device, mappings must not overlap
390
+ if src_device == dst_device:
391
+ remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
392
+ dst_blocks = random.sample(remaining_blocks, num_mappings)
393
+ else:
394
+ dst_blocks = random.sample(range(num_blocks), num_mappings)
395
+
396
+ block_mapping = list(zip(src_blocks, dst_blocks))
397
+ block_mapping_tensor = torch.tensor(
398
+ block_mapping, dtype=torch.int64, device="cpu"
399
+ ).view(-1, 2)
400
+
401
+ # Create the KV caches on the first device.
402
+ src_key_caches, src_value_caches = kv_cache_factory(
403
+ num_blocks,
404
+ block_size,
405
+ 1,
406
+ num_heads,
407
+ head_size,
408
+ kv_cache_dtype,
409
+ dtype,
410
+ seed,
411
+ src_device,
412
+ )
413
+
414
+ # Create the KV caches on the second device.
415
+ dist_key_caches, dist_value_caches = kv_cache_factory(
416
+ num_blocks,
417
+ block_size,
418
+ 1,
419
+ num_heads,
420
+ head_size,
421
+ kv_cache_dtype,
422
+ dtype,
423
+ seed,
424
+ dst_device,
425
+ )
426
+
427
+ src_key_caches_clone = src_key_caches[0].clone()
428
+ src_value_caches_clone = src_value_caches[0].clone()
429
+
430
+ # Call the swap_blocks kernel.
431
+ do_opcheck = head_size == HEAD_SIZES[0]
432
+ opcheck(
433
+ ops.ops.swap_blocks,
434
+ (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
435
+ cond=do_opcheck,
436
+ )
437
+ opcheck(
438
+ ops.ops.swap_blocks,
439
+ (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
440
+ cond=do_opcheck,
441
+ )
442
+
443
+ ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
444
+ ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
445
+
446
+ for src, dst in block_mapping:
447
+ torch.testing.assert_close(
448
+ src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
449
+ )
450
+ torch.testing.assert_close(
451
+ src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
452
+ )
453
+
454
+
455
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
456
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
457
+ @pytest.mark.parametrize("block_size", BLOCK_SIZES)
458
+ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
459
+ @pytest.mark.parametrize("dtype", DTYPES)
460
+ @pytest.mark.parametrize("seed", SEEDS)
461
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
462
+ @torch.inference_mode()
463
+ def test_fp8_e4m3_conversion(
464
+ num_heads: int,
465
+ head_size: int,
466
+ block_size: int,
467
+ num_blocks: int,
468
+ dtype: torch.dtype,
469
+ seed: int,
470
+ device: str,
471
+ ) -> None:
472
+ current_platform.seed_everything(seed)
473
+
474
+ low = -224.0
475
+ high = 224.0
476
+ shape = (num_blocks, num_heads, head_size, block_size)
477
+ cache = torch.empty(shape, dtype=dtype, device=device)
478
+ cache.uniform_(low, high)
479
+
480
+ cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
481
+ ops.convert_fp8(cache_fp8, cache)
482
+
483
+ converted_cache = torch.empty_like(cache)
484
+ ops.convert_fp8(converted_cache, cache_fp8)
485
+
486
+ torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
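
The swap test above only checks per-block equality after the call, so the intended semantics can be sketched as a plain block-wise copy (the real op also handles CPU-GPU transfers); illustrative only:

```python
import torch


def swap_blocks_reference(
    src_cache: torch.Tensor,
    dst_cache: torch.Tensor,
    block_mapping: torch.Tensor,  # [num_pairs, 2] of (src, dst) block indices
) -> None:
    for src, dst in block_mapping.tolist():
        dst_cache[dst].copy_(src_cache[src])
```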
tests/kernels/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ """Kernel test utils"""
2
+
3
+ import itertools
4
+ import random
5
+ import unittest
6
+ from functools import lru_cache
7
+ from numbers import Number
8
+ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
9
+
10
+ import pytest
11
+ import torch
12
+ from torch._prims_common import TensorLikeType
13
+
14
+ # For now, disable "test_aot_dispatch_dynamic" since there are some
15
+ # bugs related to this test in PyTorch 2.4.
16
+ DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
17
+ "test_schema",
18
+ "test_autograd_registration",
19
+ "test_faketensor",
20
+ )
21
+
22
+ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
23
+ "test_schema",
24
+ "test_autograd_registration",
25
+ "test_faketensor",
26
+ "test_aot_dispatch_dynamic",
27
+ )
28
+
29
+
30
+ # Copied/modified from torch._refs.__init__.py
31
+ def fp8_allclose(
32
+ a: TensorLikeType,
33
+ b: TensorLikeType,
34
+ rtol: float = 1e-05,
35
+ atol: float = 1e-08,
36
+ equal_nan: bool = False,
37
+ ) -> bool:
38
+ """
39
+ Reference implementation of torch.allclose
40
+ """
41
+ torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
42
+
43
+ return bool(
44
+ torch.all(
45
+ torch.isclose(
46
+ a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
47
+ )
48
+ ).item()
49
+ )
50
+
51
+
52
+ def compute_max_diff(output, output_ref):
53
+ return torch.mean(torch.abs(output - output_ref)) / torch.mean(
54
+ torch.abs(output_ref)
55
+ )
56
+
57
+
58
+ # A special version of op check that has a restricted default set of test_utils
59
+ # and a patched version of allclose that supports fp8 types.
60
+ def opcheck(
61
+ op: Union[
62
+ torch._ops.OpOverload,
63
+ torch._ops.OpOverloadPacket,
64
+ torch._library.custom_ops.CustomOpDef,
65
+ ],
66
+ args: Tuple[Any, ...],
67
+ kwargs: Optional[Dict[str, Any]] = None,
68
+ *,
69
+ test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
70
+ raise_exception: bool = True,
71
+ cond: bool = True
72
+ ) -> Dict[str, str]:
73
+ with unittest.mock.patch("torch.allclose", new=fp8_allclose):
74
+ return (
75
+ torch.library.opcheck(
76
+ op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
77
+ )
78
+ if cond
79
+ else {}
80
+ )
81
+
82
+
83
+ @lru_cache(maxsize=None)
84
+ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
85
+ """Returns the maximum shared memory per thread block in bytes."""
86
+ from paged_attention import ops
87
+
88
+ max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
89
+ # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
90
+ # will fail
91
+ assert max_shared_mem > 0, "max_shared_mem can not be zero"
92
+ return int(max_shared_mem)
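
A small usage sketch of `fp8_allclose` (import path assumed): it follows the `torch.allclose` contract but compares in float64, and it is what `opcheck` patches in place of `torch.allclose`.

```python
import torch

from tests.kernels.utils import fp8_allclose  # path assumed

a = torch.tensor([1.0, 2.0, float("nan")])
b = torch.tensor([1.0, 2.0, float("nan")])
assert fp8_allclose(a, b, equal_nan=True)  # NaNs treated as equal
assert not fp8_allclose(a, b + 1.0)        # values differ -> False
```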