EricB (HF Staff) committed
Commit 0cd49c5 · 1 Parent(s): 957a885

Add fp8 support

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
+*.metallib filter=lfs diff=lfs merge=lfs -text
build.toml CHANGED
@@ -13,6 +13,8 @@ src = [
   "paged-attention-metal/attention/paged_attention.metal",
   "paged-attention-metal/cache/copy_blocks.metal",
   "paged-attention-metal/cache/reshape_and_cache.metal",
+  "paged-attention-metal/convert_fp8.metal",
+  "paged-attention-metal/float8.metal",
   "paged-attention-metal/utils.metal",
   "paged-attention-metal/paged_attention.mm",
   "paged-attention-metal/cache.mm",
paged-attention-metal/attention/paged_attention.metal CHANGED
@@ -1,6 +1,7 @@
 // Updated from MLX commit has f70764a
 
 #include "../utils.metal"
+#include "../float8.metal"
 #include <metal_simdgroup>
 #include <metal_stdlib>
 
@@ -529,6 +530,154 @@ inline void from_float(thread Half8_ &dst, Float8_ src) {
   dst.y = y;
 }
 
+// ========================================== FP8 (uchar) vector data types.
+
+// 8-lane uchar vector – Metal only provides up to uchar4, so build our own.
+struct Uchar8_ {
+  uchar4 x;
+  uchar4 y;
+};
+
+// Vec specialisations so Vec<uchar, N>::Type resolves correctly.
+template <> struct Vec<uchar, 1> {
+  using Type = uchar;
+};
+template <> struct Vec<uchar, 2> {
+  using Type = uchar2;
+};
+template <> struct Vec<uchar, 4> {
+  using Type = uchar4;
+};
+template <> struct Vec<uchar, 8> {
+  using Type = Uchar8_;
+};
+
+// General case: not uchar
+template <typename T> inline constexpr bool is_uchar() { return false; }
+
+// Specialization: T is uchar
+template <> inline constexpr bool is_uchar<uchar>() { return true; }
+
+// Generic fallback – will fail to compile if a required specialisation is
+// missing.
+template <typename Vec, typename Quant_vec>
+inline Vec fp8_convert(const thread Quant_vec &, float scale) {
+  static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation");
+}
+
+// ========================================== FP8 → float/half/bfloat
+inline float __dequant_single(uchar v, float scale) {
+  return fp8_e4m3_to_float(v) * scale;
+}
+
+// ---- 1-lane ----
+template <>
+inline float fp8_convert<float, uchar>(const thread uchar &in, float scale) {
+  return __dequant_single(in, scale);
+}
+template <>
+inline half fp8_convert<half, uchar>(const thread uchar &in, float scale) {
+  return half(__dequant_single(in, scale));
+}
+template <>
+inline bfloat16_t fp8_convert<bfloat16_t, uchar>(const thread uchar &in,
+                                                 float scale) {
+  return bfloat16_t(__dequant_single(in, scale));
+}
+
+// ---- 2-lane ----
+template <>
+inline float2 fp8_convert<float2, uchar2>(const thread uchar2 &in,
+                                          float scale) {
+  return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale));
+}
+template <>
+inline half2 fp8_convert<half2, uchar2>(const thread uchar2 &in, float scale) {
+  half2 out;
+  out.x = half(__dequant_single(in.x, scale));
+  out.y = half(__dequant_single(in.y, scale));
+  return out;
+}
+template <>
+inline Bfloat2_ fp8_convert<Bfloat2_, uchar2>(const thread uchar2 &in,
+                                              float scale) {
+  Bfloat2_ out;
+  out.x = bfloat16_t(__dequant_single(in.x, scale));
+  out.y = bfloat16_t(__dequant_single(in.y, scale));
+  return out;
+}
+
+// ---- 4-lane ----
+template <>
+inline float4 fp8_convert<float4, uchar4>(const thread uchar4 &in,
+                                          float scale) {
+  return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale),
+                __dequant_single(in.z, scale), __dequant_single(in.w, scale));
+}
+template <>
+inline half4 fp8_convert<half4, uchar4>(const thread uchar4 &in, float scale) {
+  half4 out;
+  out.x = half(__dequant_single(in.x, scale));
+  out.y = half(__dequant_single(in.y, scale));
+  out.z = half(__dequant_single(in.z, scale));
+  out.w = half(__dequant_single(in.w, scale));
+  return out;
+}
+template <>
+inline Bfloat4_ fp8_convert<Bfloat4_, uchar4>(const thread uchar4 &in,
+                                              float scale) {
+  Bfloat4_ out;
+  out.x.x = bfloat16_t(__dequant_single(in.x, scale));
+  out.x.y = bfloat16_t(__dequant_single(in.y, scale));
+  out.y.x = bfloat16_t(__dequant_single(in.z, scale));
+  out.y.y = bfloat16_t(__dequant_single(in.w, scale));
+  return out;
+}
+
+// ---- 8-lane ----
+template <>
+inline Float8_ fp8_convert<Float8_, Uchar8_>(const thread Uchar8_ &in,
+                                             float scale) {
+  Float8_ out;
+  out.x =
+      float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale),
+             __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale));
+  out.y =
+      float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale),
+             __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale));
+  return out;
+}
+template <>
+inline Half8_ fp8_convert<Half8_, Uchar8_>(const thread Uchar8_ &in,
+                                           float scale) {
+  Half8_ out;
+  out.x = half4(half(__dequant_single(in.x.x, scale)),
+                half(__dequant_single(in.x.y, scale)),
+                half(__dequant_single(in.x.z, scale)),
+                half(__dequant_single(in.x.w, scale)));
+  out.y = half4(half(__dequant_single(in.y.x, scale)),
+                half(__dequant_single(in.y.y, scale)),
+                half(__dequant_single(in.y.z, scale)),
+                half(__dequant_single(in.y.w, scale)));
+  return out;
+}
+template <>
+inline Bfloat8_ fp8_convert<Bfloat8_, Uchar8_>(const thread Uchar8_ &in,
+                                               float scale) {
+  Bfloat8_ out;
+  // first 4
+  out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale));
+  out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale));
+  out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale));
+  out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale));
+  // second 4
+  out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale));
+  out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale));
+  out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale));
+  out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale));
+  return out;
+}
+
 // ========================================== Dot product utilities
 
 // TODO(EricLBuehler): optimize with vectorization
@@ -602,8 +751,9 @@ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
 
 constant bool use_partitioning [[function_constant(10)]];
 constant bool use_alibi [[function_constant(20)]];
+constant bool use_fp8_scales [[function_constant(30)]];
 
-template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
+template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
           int NUM_SIMD_LANES, int PARTITION_SIZE = 0>
 [[kernel]] void paged_attention(
     device float *exp_sums
@@ -615,22 +765,26 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
     device T *out
     [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size]
     device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size]
-    device const T *k_cache
+    device const CACHE_T *k_cache
     [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    device const T *v_cache
+    device const CACHE_T *v_cache
    [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size]
-    const constant int &num_kv_heads [[buffer(6)]], // [num_heads]
-    const constant float &scale [[buffer(7)]],
-    const constant float &softcapping [[buffer(8)]],
+    const device float *__restrict__ k_scale
+    [[buffer(6)]], // [1] - only used when use_fp8_scales
+    const device float *__restrict__ v_scale
+    [[buffer(7)]], // [1] - only used when use_fp8_scales
+    const constant int &num_kv_heads [[buffer(8)]], // [num_heads]
+    const constant float &scale [[buffer(9)]],
+    const constant float &softcapping [[buffer(10)]],
     device const uint32_t *block_tables
-    [[buffer(9)]], // [num_seqs, max_num_blocks_per_seq]
-    device const uint32_t *context_lens [[buffer(10)]], // [num_seqs]
-    const constant int &max_num_blocks_per_seq [[buffer(11)]],
+    [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq]
+    device const uint32_t *context_lens [[buffer(12)]], // [num_seqs]
+    const constant int &max_num_blocks_per_seq [[buffer(13)]],
     device const float *alibi_slopes
-    [[buffer(12)]], // [num_heads] - only used when use_alibi
-    const constant int &q_stride [[buffer(13)]],
-    const constant int &kv_block_stride [[buffer(14)]],
-    const constant int &kv_head_stride [[buffer(15)]],
+    [[buffer(14)]], // [num_heads] - only used when use_alibi
+    const constant int &q_stride [[buffer(15)]],
+    const constant int &kv_block_stride [[buffer(16)]],
+    const constant int &kv_head_stride [[buffer(17)]],
     threadgroup char *shared_mem [[threadgroup(0)]],
     uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
     uint3 threadgroups_per_grid [[threadgroups_per_grid]],
@@ -690,6 +844,7 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1);
   using K_vec = typename Vec<T, VEC_SIZE>::Type;
   using Q_vec = typename Vec<T, VEC_SIZE>::Type;
+  using Quant_vec = typename Vec<CACHE_T, VEC_SIZE>::Type;
 
   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -720,7 +875,7 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
 
   // x == THREAD_GROUP_SIZE * VEC_SIZE
   // Each thread group fetches x elements from the key at a time.
-  constexpr int x = 16 / sizeof(T);
+  constexpr int x = 16 / sizeof(CACHE_T);
   float qk_max = -FLT_MAX;
 
   // Iterate over the key blocks.
@@ -750,14 +905,23 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
 
 #pragma unroll
    for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const device T *k_ptr =
+      const device CACHE_T *k_ptr =
          k_cache + physical_block_number * kv_block_stride +
          kv_head_idx * kv_head_stride + physical_block_offset * x;
      const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
      const int offset1 = (vec_idx * VEC_SIZE) / x;
      const int offset2 = (vec_idx * VEC_SIZE) % x;
-      k_vecs[j] = *reinterpret_cast<const device K_vec *>(
-          k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+
+      if constexpr (is_uchar<CACHE_T>()) {
+        // FP8 support
+        Quant_vec k_vec_quant = *reinterpret_cast<const device Quant_vec *>(
+            k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        k_vecs[j] = fp8_convert<K_vec, Quant_vec>(k_vec_quant, *k_scale);
+      } else {
+        // Non-FP8 default
+        k_vecs[j] = *reinterpret_cast<const device K_vec *>(
+            k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+      }
    }
 
    // Compute dot product.
@@ -844,6 +1008,7 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
   using V_vec = typename Vec<T, V_VEC_SIZE>::Type;
   using L_vec = typename Vec<T, V_VEC_SIZE>::Type;
   using Float_L_vec = typename FloatVec<L_vec>::Type;
+  using V_quant_vec = typename Vec<CACHE_T, V_VEC_SIZE>::Type;
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW;
@@ -872,8 +1037,8 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
        logits + token_idx - start_token_idx);
    from_float(logits_vec, logits_float_vec);
 
-    const device T *v_ptr = v_cache + physical_block_number * kv_block_stride +
-                            kv_head_idx * kv_head_stride;
+    const device CACHE_T *v_ptr = v_cache + physical_block_number * kv_block_stride +
+                                  kv_head_idx * kv_head_stride;
 #pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -883,7 +1048,18 @@ template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
      // we should explicitly zero out the values since they may contain NaNs.
      // See
      // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
-      V_vec v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
+      V_vec v_vec;
+
+      if constexpr (is_uchar<CACHE_T>()) {
+        // FP8 support
+        V_quant_vec v_quant_vec =
+            *reinterpret_cast<const device V_quant_vec *>(v_ptr + offset);
+        v_vec = fp8_convert<V_vec, V_quant_vec>(v_quant_vec, *v_scale);
+      } else {
+        // Non-FP8 default
+        v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
+      }
+
      if (block_idx == num_context_blocks - 1) {
        thread T *v_vec_ptr = reinterpret_cast<thread T *>(&v_vec);
 #pragma unroll
@@ -1073,36 +1249,38 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
   }
 }
 
-#define instantiate_paged_attention_inner( \
-    type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
-  template \
-      [[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
-                   "_nt" #num_threads "_nsl" #num_simd_lanes \
-                   "_ps" #partition_size)]] [[kernel]] void \
-      paged_attention<type, head_size, block_size, num_threads, \
-                      num_simd_lanes, partition_size>( \
-          device float *exp_sums [[buffer(0)]], \
-          device float *max_logits [[buffer(1)]], \
-          device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \
-          device const type *k_cache [[buffer(4)]], \
-          device const type *v_cache [[buffer(5)]], \
-          const constant int &num_kv_heads [[buffer(6)]], \
-          const constant float &scale [[buffer(7)]], \
-          const constant float &softcapping [[buffer(8)]], \
-          device const uint32_t *block_tables [[buffer(9)]], \
-          device const uint32_t *context_lens [[buffer(10)]], \
-          const constant int &max_num_blocks_per_seq [[buffer(11)]], \
-          device const float *alibi_slopes [[buffer(12)]], \
-          const constant int &q_stride [[buffer(13)]], \
-          const constant int &kv_block_stride [[buffer(14)]], \
-          const constant int &kv_head_stride [[buffer(15)]], \
-          threadgroup char *shared_mem [[threadgroup(0)]], \
-          uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
-          uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
-          uint3 thread_position_in_threadgroup \
-          [[thread_position_in_threadgroup]], \
-          uint simd_tid [[simdgroup_index_in_threadgroup]], \
-          uint simd_lid [[thread_index_in_simdgroup]]);
+#define instantiate_paged_attention_inner(type, cache_type, head_size, \
+                                          block_size, num_threads, \
+                                          num_simd_lanes, partition_size) \
+  template [[host_name("paged_attention_" #type "_cache_" #cache_type \
+                       "_hs" #head_size "_bs" #block_size "_nt" #num_threads \
+                       "_nsl" #num_simd_lanes \
+                       "_ps" #partition_size)]] [[kernel]] void \
+  paged_attention<type, cache_type, head_size, block_size, num_threads, \
+                  num_simd_lanes, partition_size>( \
+      device float *exp_sums [[buffer(0)]], \
+      device float *max_logits [[buffer(1)]], \
+      device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \
+      device const cache_type *k_cache [[buffer(4)]], \
+      device const cache_type *v_cache [[buffer(5)]], \
+      const device float *__restrict__ k_scale [[buffer(6)]], \
+      const device float *__restrict__ v_scale [[buffer(7)]], \
+      const constant int &num_kv_heads [[buffer(8)]], \
+      const constant float &scale [[buffer(9)]], \
+      const constant float &softcapping [[buffer(10)]], \
+      device const uint32_t *block_tables [[buffer(11)]], \
+      device const uint32_t *context_lens [[buffer(12)]], \
+      const constant int &max_num_blocks_per_seq [[buffer(13)]], \
+      device const float *alibi_slopes [[buffer(14)]], \
+      const constant int &q_stride [[buffer(15)]], \
+      const constant int &kv_block_stride [[buffer(16)]], \
+      const constant int &kv_head_stride [[buffer(17)]], \
+      threadgroup char *shared_mem [[threadgroup(0)]], \
+      uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
+      uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
+      uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
+      uint simd_tid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
 
 #define instantiate_paged_attention_v2_reduce_inner( \
     type, head_size, num_threads, num_simd_lanes, partition_size) \
@@ -1125,26 +1303,35 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
       uint simd_tid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);
 
-#define instantiate_paged_attention_heads(type, block_size, num_threads, \
-                                          num_simd_lanes, partition_size) \
-  instantiate_paged_attention_inner(type, 32, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 64, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 80, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 96, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 112, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 120, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 128, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 192, block_size, num_threads, \
-                                    num_simd_lanes, partition_size); \
-  instantiate_paged_attention_inner(type, 256, block_size, num_threads, \
-                                    num_simd_lanes, partition_size);
+#define instantiate_paged_attention_heads( \
+    type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \
+  instantiate_paged_attention_inner(type, cache_type, 32, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 64, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 80, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 96, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 112, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 120, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 128, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 192, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size); \
+  instantiate_paged_attention_inner(type, cache_type, 256, block_size, \
+                                    num_threads, num_simd_lanes, \
+                                    partition_size);
 
 #define instantiate_paged_attention_v2_reduce_heads( \
     type, num_threads, num_simd_lanes, partition_size) \
@@ -1167,30 +1354,48 @@ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
   instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \
                                               num_simd_lanes, partition_size);
 
-#define instantiate_paged_attention_block_size(type, num_threads, \
+#define instantiate_paged_attention_block_size(type, cache_type, num_threads, \
                                                num_simd_lanes, partition_size) \
-  instantiate_paged_attention_heads(type, 8, num_threads, num_simd_lanes, \
-                                    partition_size); \
-  instantiate_paged_attention_heads(type, 16, num_threads, num_simd_lanes, \
-                                    partition_size); \
-  instantiate_paged_attention_heads(type, 32, num_threads, num_simd_lanes, \
-                                    partition_size);
+  instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \
+                                    num_simd_lanes, partition_size); \
+  instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \
+                                    num_simd_lanes, partition_size); \
+  instantiate_paged_attention_heads(type, cache_type, 32, num_threads, \
+                                    num_simd_lanes, partition_size);
 
 // TODO: tune num_threads = 256
 // NOTE: partition_size = 0
-#define instantiate_paged_attention_v1(type, num_simd_lanes) \
-  instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
+#define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes) \
+  instantiate_paged_attention_block_size(type, cache_type, 256, \
+                                         num_simd_lanes, 0);
 
 // TODO: tune num_threads = 256
 // NOTE: partition_size = 512
-#define instantiate_paged_attention_v2(type, num_simd_lanes) \
-  instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
+#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \
+  instantiate_paged_attention_block_size(type, cache_type, 256, \
+                                         num_simd_lanes, 512);
+
+// TODO: tune num_threads = 256
+// NOTE: partition_size = 512
+#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \
   instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
 
-instantiate_paged_attention_v1(float, 32);
-instantiate_paged_attention_v1(bfloat16_t, 32);
-instantiate_paged_attention_v1(half, 32);
+instantiate_paged_attention_v1(float, float, 32);
+instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
+instantiate_paged_attention_v1(half, half, 32);
+
+instantiate_paged_attention_v1(float, uchar, 32);
+instantiate_paged_attention_v1(bfloat16_t, uchar, 32);
+instantiate_paged_attention_v1(half, uchar, 32);
+
+instantiate_paged_attention_v2_reduce(float, 32);
+instantiate_paged_attention_v2_reduce(bfloat16_t, 32);
+instantiate_paged_attention_v2_reduce(half, 32);
+
+instantiate_paged_attention_v2(float, float, 32);
+instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32);
+instantiate_paged_attention_v2(half, half, 32);
 
-instantiate_paged_attention_v2(float, 32);
-instantiate_paged_attention_v2(bfloat16_t, 32);
-instantiate_paged_attention_v2(half, 32);
+instantiate_paged_attention_v2(float, uchar, 32);
+instantiate_paged_attention_v2(bfloat16_t, uchar, 32);
+instantiate_paged_attention_v2(half, uchar, 32);
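Note on kernel naming: the updated instantiate_paged_attention_inner macro now encodes both the compute dtype and the cache dtype in the [[host_name]] string. A minimal host-side sketch of that naming scheme (illustrative only; this helper is not part of the commit):

#include <cstdio>
#include <string>

// Mirrors the "paged_attention_<type>_cache_<cache_type>_hs.._bs.._nt.._nsl.._ps.."
// pattern stamped out by the instantiation macro above.
static std::string paged_attention_kernel_name(const std::string &type,
                                               const std::string &cache_type,
                                               int head_size, int block_size,
                                               int num_threads, int num_simd_lanes,
                                               int partition_size) {
  return "paged_attention_" + type + "_cache_" + cache_type + "_hs" +
         std::to_string(head_size) + "_bs" + std::to_string(block_size) + "_nt" +
         std::to_string(num_threads) + "_nsl" + std::to_string(num_simd_lanes) +
         "_ps" + std::to_string(partition_size);
}

int main() {
  // One of the v1 fp8 variants instantiated at the bottom of the file:
  std::printf("%s\n",
              paged_attention_kernel_name("half", "uchar", 128, 16, 256, 32, 0).c_str());
  // -> paged_attention_half_cache_uchar_hs128_bs16_nt256_nsl32_ps0
}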
paged-attention-metal/cache.mm CHANGED
@@ -147,6 +147,9 @@ void copy_blocks(const std::vector<torch::Tensor> &key_caches,
   case torch::kBFloat16:
    kernName = @"copy_blocks_bfloat16_t";
    break;
+  case torch::kUInt8:
+    kernName = @"copy_blocks_uchar";
+    break;
   default:
    TORCH_CHECK(false, "Unsupported dtype for copy_blocks");
  }
@@ -214,6 +217,16 @@ void reshape_and_cache(
    const std::string &kv_cache_dtype, torch::Tensor &k_scale,
    torch::Tensor &v_scale) {
 
+  // Determine cache dtype and FP8 usage
+  torch::ScalarType cache_dtype = key_cache.scalar_type();
+  bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
+  if (use_fp8_scales) {
+    TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type");
+    TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars");
+    TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32,
+                "FP8 scales must be float32");
+  }
+
   TORCH_CHECK(key.device().is_mps() && value.device().is_mps() &&
                  key_cache.device().is_mps() && value_cache.device().is_mps(),
              "All tensors must be on MPS device");
@@ -256,22 +269,51 @@ void reshape_and_cache(
  }
 
   NSString *kernName = nil;
+  std::string kv_dtype_str, cache_dtype_str;
+
+  // Get KV dtype string
   switch (key.scalar_type()) {
   case torch::kFloat:
-    kernName = @"reshape_and_cache_float";
+    kv_dtype_str = "float";
    break;
   case torch::kHalf:
-    kernName = @"reshape_and_cache_half";
+    kv_dtype_str = "half";
    break;
   case torch::kBFloat16:
-    kernName = @"reshape_and_cache_bfloat16_t";
+    kv_dtype_str = "bfloat16_t";
    break;
   default:
    TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache");
  }
-
-  id<MTLFunction> fn = [lib newFunctionWithName:kernName];
-  TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String);
+
+  // Get cache dtype string
+  switch (cache_dtype) {
+  case torch::kFloat:
+    cache_dtype_str = "float";
+    break;
+  case torch::kHalf:
+    cache_dtype_str = "half";
+    break;
+  case torch::kBFloat16:
+    cache_dtype_str = "bfloat16_t";
+    break;
+  case torch::kUInt8:
+    cache_dtype_str = "uchar";
+    break;
+  default:
+    TORCH_CHECK(false, "Unsupported cache dtype for reshape_and_cache");
+  }
+
+  std::string kernName_str = "reshape_and_cache_kv_" + kv_dtype_str + "_cache_" + cache_dtype_str;
+  kernName = [NSString stringWithUTF8String:kernName_str.c_str()];
+
+  // Create function constants for FP8 support
+  MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init];
+  [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:10];
+
+  id<MTLFunction> fn = [lib newFunctionWithName:kernName constantValues:constants error:&error];
+  TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String,
+              error ? [NSString stringWithFormat:@": %@", error.localizedDescription].UTF8String : "");
 
   id<MTLComputePipelineState> pso =
      [device newComputePipelineStateWithFunction:fn error:&error];
@@ -305,46 +347,59 @@ void reshape_and_cache(
                           options:MTLResourceStorageModeShared];
    [enc setBuffer:slotMappingBuf offset:0 atIndex:4];
 
+    // k_scale and v_scale buffers (for FP8)
+    if (use_fp8_scales) {
+      [enc setBuffer:getMTLBufferStorage(k_scale)
+              offset:k_scale.storage_offset() * k_scale.element_size()
+             atIndex:5];
+      [enc setBuffer:getMTLBufferStorage(v_scale)
+              offset:v_scale.storage_offset() * v_scale.element_size()
+             atIndex:6];
+    } else {
+      // For non-FP8, we still need to increment buffer indices
+      // The Metal kernel expects buffers at indices 5 and 6 even if unused
+    }
+
    // Set parameters as individual buffers (matching mistralrs pattern)
    id<MTLBuffer> keyStrideBuf =
        [device newBufferWithBytes:&key_stride
                            length:sizeof(int32_t)
                           options:MTLResourceStorageModeShared];
-    [enc setBuffer:keyStrideBuf offset:0 atIndex:5];
+    [enc setBuffer:keyStrideBuf offset:0 atIndex:7];
 
    id<MTLBuffer> valueStrideBuf =
        [device newBufferWithBytes:&value_stride
                            length:sizeof(int32_t)
                           options:MTLResourceStorageModeShared];
-    [enc setBuffer:valueStrideBuf offset:0 atIndex:6];
+    [enc setBuffer:valueStrideBuf offset:0 atIndex:8];
 
    const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
    id<MTLBuffer> numHeadsBuf =
        [device newBufferWithBytes:&num_heads_i32
                            length:sizeof(int32_t)
                           options:MTLResourceStorageModeShared];
-    [enc setBuffer:numHeadsBuf offset:0 atIndex:7];
+    [enc setBuffer:numHeadsBuf offset:0 atIndex:9];
 
    const int32_t head_size_i32 = static_cast<int32_t>(head_size);
    id<MTLBuffer> headSizeBuf =
        [device newBufferWithBytes:&head_size_i32
                            length:sizeof(int32_t)
                           options:MTLResourceStorageModeShared];
-    [enc setBuffer:headSizeBuf offset:0 atIndex:8];
+    [enc setBuffer:headSizeBuf offset:0 atIndex:10];
 
    const int32_t block_size_i32 = static_cast<int32_t>(block_size);
    id<MTLBuffer> blockSizeBuf =
        [device newBufferWithBytes:&block_size_i32
                            length:sizeof(int32_t)
                           options:MTLResourceStorageModeShared];
-    [enc setBuffer:blockSizeBuf offset:0 atIndex:9];
+    [enc setBuffer:blockSizeBuf offset:0 atIndex:11];
 
    const int32_t x_i32 = static_cast<int32_t>(x);
    id<MTLBuffer> xBuf =
        [device newBufferWithBytes:&x_i32
                            length:sizeof(int32_t)
                           options:MTLResourceStorageModeShared];
-    [enc setBuffer:xBuf offset:0 atIndex:10];
+    [enc setBuffer:xBuf offset:0 atIndex:12];
 
    const uint64_t threads_per_threadgroup =
        std::min<uint64_t>(512, num_heads * head_size);
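Because k_scale and v_scale now occupy buffer indices 5 and 6, every parameter buffer after the slot mapping shifts up by two. A sketch of the index layout the host encoder and reshape_and_cache kernel now agree on (enum names are illustrative; the indices come from the diff):

// Buffer indices shared by cache.mm and reshape_and_cache.metal after this change.
enum ReshapeAndCacheBufferIndex {
  kKey = 0,
  kValue = 1,
  kKeyCache = 2,
  kValueCache = 3,
  kSlotMapping = 4,
  kKScale = 5,  // read only when the use_fp8_scales function constant is set
  kVScale = 6,  // read only when the use_fp8_scales function constant is set
  kKeyStride = 7,
  kValueStride = 8,
  kNumHeads = 9,
  kHeadSize = 10,
  kBlockSize = 11,
  kX = 12,
};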
paged-attention-metal/cache/copy_blocks.metal CHANGED
@@ -48,3 +48,4 @@ template <typename T>
 instantiate_copy_blocks(float);
 instantiate_copy_blocks(bfloat16_t);
 instantiate_copy_blocks(half);
+instantiate_copy_blocks(uchar);
paged-attention-metal/cache/reshape_and_cache.metal CHANGED
@@ -1,23 +1,56 @@
 #include "../utils.metal"
+#include "../float8.metal"
 #include <metal_stdlib>
 
 using namespace metal;
 
-template <typename T>
+template <typename KV_T, typename CACHE_T>
+inline CACHE_T to_cache(KV_T v) = delete;
+
+template <> inline uchar to_cache<float, uchar>(float v) {
+  return float_to_fp8_e4m3(v);
+}
+
+template <> inline uchar to_cache<bfloat16_t, uchar>(bfloat16_t v) {
+  return float_to_fp8_e4m3((float)v);
+}
+
+template <> inline uchar to_cache<half, uchar>(half v) {
+  return float_to_fp8_e4m3((float)v);
+}
+
+template <> inline float to_cache<float, float>(float v) { return v; }
+
+template <> inline bfloat16_t to_cache<bfloat16_t, bfloat16_t>(bfloat16_t v) {
+  return v;
+}
+
+template <> inline half to_cache<half, half>(half v) { return v; }
+
+constant bool use_fp8_scales [[function_constant(10)]];
+
+template <typename KV_T, typename CACHE_T>
 [[kernel]] void reshape_and_cache(
-    const device T *__restrict__ key
+    const device KV_T *__restrict__ key
    [[buffer(0)]], // [num_tokens, num_heads, head_size]
-    const device T *__restrict__ value
+    const device KV_T *__restrict__ value
    [[buffer(1)]], // [num_tokens, num_heads, head_size]
-    device T *__restrict__ key_cache
+    device CACHE_T *__restrict__ key_cache
    [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x]
-    device T *__restrict__ value_cache
+    device CACHE_T *__restrict__ value_cache
    [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size]
    const device int64_t *__restrict__ slot_mapping
    [[buffer(4)]], // [num_tokens]
-    device const int &key_stride, device const int &value_stride,
-    device const int &num_heads, device const int &head_size,
-    device const int &block_size, device const int &x,
+    const device float *__restrict__ k_scale
+    [[buffer(5)]], // [1] - only used when use_fp8_scales
+    const device float *__restrict__ v_scale
+    [[buffer(6)]], // [1] - only used when use_fp8_scales
+    device const int &key_stride [[buffer(7)]],
+    device const int &value_stride [[buffer(8)]],
+    device const int &num_heads [[buffer(9)]],
+    device const int &head_size [[buffer(10)]],
+    device const int &block_size [[buffer(11)]],
+    device const int &x [[buffer(12)]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threads_per_threadgroup [[threads_per_threadgroup]]) {
@@ -49,29 +82,47 @@ template <typename T>
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;
-    key_cache[tgt_key_idx] = key[src_key_idx];
-    value_cache[tgt_value_idx] = value[src_value_idx];
+
+    if (use_fp8_scales) {
+      key_cache[tgt_key_idx] =
+          to_cache<KV_T, CACHE_T>(KV_T((float)key[src_key_idx] / *k_scale));
+      value_cache[tgt_value_idx] =
+          to_cache<KV_T, CACHE_T>(KV_T((float)value[src_value_idx] / *v_scale));
+    } else {
+      key_cache[tgt_key_idx] = to_cache<KV_T, CACHE_T>(key[src_key_idx]);
+      value_cache[tgt_value_idx] = to_cache<KV_T, CACHE_T>(value[src_value_idx]);
+    }
  }
 }
 
-#define instantiate_reshape_and_cache(type) \
-  template [[host_name("reshape_and_cache_" #type)]] [[kernel]] void \
-  reshape_and_cache<type>( \
-      const device type *__restrict__ key [[buffer(0)]], \
-      const device type *__restrict__ value [[buffer(1)]], \
-      device type *__restrict__ key_cache [[buffer(2)]], \
-      device type *__restrict__ value_cache [[buffer(3)]], \
+#define instantiate_reshape_and_cache(kv_type, cache_type) \
+  template [[host_name("reshape_and_cache_kv_" #kv_type \
+                       "_cache_" #cache_type)]] [[kernel]] void \
+  reshape_and_cache<kv_type, cache_type>( \
+      const device kv_type *__restrict__ key [[buffer(0)]], \
+      const device kv_type *__restrict__ value [[buffer(1)]], \
+      device cache_type *__restrict__ key_cache [[buffer(2)]], \
+      device cache_type *__restrict__ value_cache [[buffer(3)]], \
      const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
-      device const int &key_stride, device const int &value_stride, \
-      device const int &num_heads, device const int &head_size, \
-      device const int &block_size, device const int &x, \
+      const device float *__restrict__ k_scale [[buffer(5)]], \
+      const device float *__restrict__ v_scale [[buffer(6)]], \
+      device const int &key_stride [[buffer(7)]], \
+      device const int &value_stride [[buffer(8)]], \
+      device const int &num_heads [[buffer(9)]], \
+      device const int &head_size [[buffer(10)]], \
+      device const int &block_size [[buffer(11)]], \
+      device const int &x [[buffer(12)]], \
      uint gid [[threadgroup_position_in_grid]], \
      uint tid [[thread_position_in_threadgroup]], \
      uint threads_per_threadgroup [[threads_per_threadgroup]]);
 
-instantiate_reshape_and_cache(float);
-instantiate_reshape_and_cache(bfloat16_t);
-instantiate_reshape_and_cache(half);
+instantiate_reshape_and_cache(float, float);
+instantiate_reshape_and_cache(bfloat16_t, bfloat16_t);
+instantiate_reshape_and_cache(half, half);
+
+instantiate_reshape_and_cache(float, uchar);
+instantiate_reshape_and_cache(bfloat16_t, uchar);
+instantiate_reshape_and_cache(half, uchar);
 
 // Flash version with different cache layout: [num_blocks, block_size,
 // num_heads, head_size]
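The write path above divides each value by the per-tensor scale before encoding to E4M3, and the attention kernel multiplies by the same scale after decoding, so a cached entry comes back as roughly (key / k_scale) rounded to fp8, times k_scale. A minimal host-side round-trip check of that convention (a sketch only: the decode mirrors fp8_e4m3_to_float from float8.metal, and the encoder is a brute-force nearest-code search used purely as a test oracle, not the kernel's rounding path):

#include <cmath>
#include <cstdint>
#include <cstdio>

static float decode_e4m3(uint8_t v) {
  const int s = v >> 7, exp = (v >> 3) & 0xF, man = v & 0x7;
  float val;
  if (exp == 0)
    val = std::ldexp(man / 8.0f, -6);            // zero / subnormal
  else if (exp == 0xF)
    val = man ? NAN : INFINITY;                  // as treated in float8.metal
  else
    val = std::ldexp(1.0f + man / 8.0f, exp - 7);
  return s ? -val : val;
}

static uint8_t encode_e4m3_nearest(float x) {
  uint8_t best = 0;
  float best_err = INFINITY;
  for (int c = 0; c < 256; ++c) {
    const float d = decode_e4m3(static_cast<uint8_t>(c));
    if (!std::isfinite(d)) continue;             // skip NaN / Inf codes
    const float err = std::fabs(d - x);
    if (err < best_err) { best_err = err; best = static_cast<uint8_t>(c); }
  }
  return best;
}

int main() {
  const float k_scale = 0.05f, key = 1.37f;
  const uint8_t stored = encode_e4m3_nearest(key / k_scale);  // what reshape_and_cache writes
  const float restored = decode_e4m3(stored) * k_scale;       // what paged_attention reads back
  std::printf("key=%.4f restored=%.4f\n", key, restored);     // equal up to fp8 rounding
}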
paged-attention-metal/convert_fp8.metal ADDED
@@ -0,0 +1,77 @@
+#include "float8.metal"
+#include "utils.metal"
+#include <metal_stdlib>
+
+using namespace metal;
+
+// Convert between different precision formats for cache tensors
+// This kernel handles conversions like float->fp8, fp8->float, etc.
+
+template <typename SRC_T, typename DST_T>
+[[kernel]] void convert_fp8_kernel(
+    const device SRC_T *__restrict__ src [[buffer(0)]],
+    device DST_T *__restrict__ dst [[buffer(1)]],
+    const device float &scale [[buffer(2)]],
+    const device uint32_t &num_elements [[buffer(3)]],
+    uint gid [[thread_position_in_grid]]) {
+
+  if (gid >= num_elements) {
+    return;
+  }
+
+  // Load source value
+  SRC_T src_val = src[gid];
+
+  // Convert based on source and destination types
+  if constexpr (is_same_v<SRC_T, uchar> && !is_same_v<DST_T, uchar>) {
+    // FP8 -> higher precision (dequantization)
+    float fp32_val = fp8_e4m3_to_float(src_val) * scale;
+    dst[gid] = static_cast<DST_T>(fp32_val);
+  } else if constexpr (!is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
+    // Higher precision -> FP8 (quantization)
+    float fp32_val = static_cast<float>(src_val) / scale;
+    dst[gid] = float_to_fp8_e4m3(fp32_val);
+  } else if constexpr (is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
+    // FP8 -> FP8 (with rescaling)
+    float fp32_val = fp8_e4m3_to_float(src_val) * scale;
+    dst[gid] = float_to_fp8_e4m3(fp32_val);
+  } else {
+    // Regular precision -> regular precision (with scaling)
+    float fp32_val = static_cast<float>(src_val) * scale;
+    dst[gid] = static_cast<DST_T>(fp32_val);
+  }
+}
+
+// Instantiate all required combinations
+#define INSTANTIATE_CONVERT_FP8(src_type, dst_type) \
+  template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]] \
+  [[kernel]] void convert_fp8_kernel<src_type, dst_type>( \
+      const device src_type *__restrict__ src [[buffer(0)]], \
+      device dst_type *__restrict__ dst [[buffer(1)]], \
+      const device float &scale [[buffer(2)]], \
+      const device uint32_t &num_elements [[buffer(3)]], \
+      uint gid [[thread_position_in_grid]]);
+
+// FP8 to other formats (dequantization)
+INSTANTIATE_CONVERT_FP8(uchar, float);
+INSTANTIATE_CONVERT_FP8(uchar, half);
+INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t);
+
+// Other formats to FP8 (quantization)
+INSTANTIATE_CONVERT_FP8(float, uchar);
+INSTANTIATE_CONVERT_FP8(half, uchar);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar);
+
+// FP8 to FP8 (rescaling)
+INSTANTIATE_CONVERT_FP8(uchar, uchar);
+
+// Regular precision conversions with scaling
+INSTANTIATE_CONVERT_FP8(float, float);
+INSTANTIATE_CONVERT_FP8(float, half);
+INSTANTIATE_CONVERT_FP8(float, bfloat16_t);
+INSTANTIATE_CONVERT_FP8(half, float);
+INSTANTIATE_CONVERT_FP8(half, half);
+INSTANTIATE_CONVERT_FP8(half, bfloat16_t);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, float);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, half);
+INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t);
paged-attention-metal/convert_fp8.mm CHANGED
@@ -1,3 +1,5 @@
+#include <ATen/mps/MPSDevice.h>
+#include <ATen/mps/MPSStream.h>
 #include <torch/torch.h>
 
 #import <Foundation/Foundation.h>
@@ -24,7 +26,113 @@ static std::string getModuleDirectory() {
   return ".";
 }
 
+// Helper function to get conversion kernel name
+static std::string getConvertKernelName(torch::ScalarType src_dtype, torch::ScalarType dst_dtype) {
+  std::string src_str, dst_str;
+
+  auto dtype_to_string = [](torch::ScalarType dtype) -> std::string {
+    switch (dtype) {
+    case torch::kFloat: return "float";
+    case torch::kHalf: return "half";
+    case torch::kBFloat16: return "bfloat16_t";
+    case torch::kUInt8: return "uchar";
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for convert_fp8: ", dtype);
+    }
+  };
+
+  src_str = dtype_to_string(src_dtype);
+  dst_str = dtype_to_string(dst_dtype);
+
+  return "convert_fp8_" + src_str + "_to_" + dst_str;
+}
+
 void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache,
                 const double scale, const std::string &kv_cache_dtype) {
-  TORCH_CHECK(false, "fp8 is not supported on Metal.");
+  // Validate input tensors
+  TORCH_CHECK(src_cache.device().is_mps() && dst_cache.device().is_mps(),
+              "Both tensors must be on MPS device");
+  TORCH_CHECK(src_cache.device() == dst_cache.device(),
+              "Source and destination tensors must be on the same device");
+  TORCH_CHECK(src_cache.numel() == dst_cache.numel(),
+              "Source and destination tensors must have the same number of elements");
+  TORCH_CHECK(src_cache.is_contiguous() && dst_cache.is_contiguous(),
+              "Both tensors must be contiguous");
+
+  const uint32_t num_elements = static_cast<uint32_t>(src_cache.numel());
+  if (num_elements == 0) {
+    return; // Nothing to convert
+  }
+
+  // Determine conversion kernel name
+  std::string kernel_name = getConvertKernelName(src_cache.scalar_type(), dst_cache.scalar_type());
+
+  @autoreleasepool {
+    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLDevice> device = stream->device();
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
+    TORCH_CHECK(cmdBuf, "Failed to get command buffer");
+
+    // Load Metal library
+    std::string moduleDir = getModuleDirectory();
+    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
+    NSString *metallibPathStr = [NSString stringWithUTF8String:metallibPath.c_str()];
+    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
+    NSError *error = nil;
+    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
+    TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ",
+                error ? error.localizedDescription.UTF8String : "unknown error");
+
+    // Create kernel function
+    NSString *kernelNameStr = [NSString stringWithUTF8String:kernel_name.c_str()];
+    id<MTLFunction> fn = [lib newFunctionWithName:kernelNameStr];
+    TORCH_CHECK(fn, "Failed to find Metal kernel function: ", kernel_name);
+
+    id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:fn error:&error];
+    TORCH_CHECK(pso, "Failed to create compute pipeline state: ",
+                error ? error.localizedDescription.UTF8String : "unknown error");
+
+    dispatch_queue_t q = stream->queue();
+    dispatch_sync(q, ^{
+      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
+      TORCH_CHECK(enc, "Failed to create compute encoder");
+
+      [enc setComputePipelineState:pso];
+
+      // Set buffers
+      [enc setBuffer:getMTLBufferStorage(src_cache)
+              offset:src_cache.storage_offset() * src_cache.element_size()
+             atIndex:0];
+      [enc setBuffer:getMTLBufferStorage(dst_cache)
+              offset:dst_cache.storage_offset() * dst_cache.element_size()
+             atIndex:1];
+
+      // Set scale parameter
+      float scale_f32 = static_cast<float>(scale);
+      id<MTLBuffer> scaleBuf = [device newBufferWithBytes:&scale_f32
+                                                   length:sizeof(float)
+                                                  options:MTLResourceStorageModeShared];
+      [enc setBuffer:scaleBuf offset:0 atIndex:2];
+
+      // Set num_elements parameter
+      id<MTLBuffer> numElementsBuf = [device newBufferWithBytes:&num_elements
+                                                         length:sizeof(uint32_t)
+                                                        options:MTLResourceStorageModeShared];
+      [enc setBuffer:numElementsBuf offset:0 atIndex:3];
+
+      // Dispatch threads
+      const uint32_t threads_per_threadgroup = std::min<uint32_t>(1024, num_elements);
+      const uint32_t threadgroups = (num_elements + threads_per_threadgroup - 1) / threads_per_threadgroup;
+
+      MTLSize threadsPerThreadgroup = MTLSizeMake(threads_per_threadgroup, 1, 1);
+      MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);
+
+      [enc dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
+      [enc endEncoding];
+    });
+
+    stream->synchronize(at::mps::SyncType::COMMIT);
+  }
 }
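Taken together, convert_fp8 now picks a convert_fp8_<src>_to_<dst> pipeline from the two tensor dtypes and applies the scale in the direction implied by the types (multiply when dequantizing from uchar, divide when quantizing to uchar). A hypothetical call site as a sketch only (assumes the extension's convert_fp8 declaration is visible and an MPS device is available):

#include <torch/torch.h>

// Declared by the extension (see convert_fp8.mm above).
void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache,
                 const double scale, const std::string &kv_cache_dtype);

void quantize_kv_cache_example() {
  auto mps = torch::TensorOptions().device(torch::kMPS);
  torch::Tensor src = torch::randn({4096}, mps.dtype(torch::kHalf));
  torch::Tensor dst = torch::empty({4096}, mps.dtype(torch::kUInt8));
  // Dispatches "convert_fp8_half_to_uchar"; each element is stored as
  // float_to_fp8_e4m3(src[i] / scale).
  convert_fp8(dst, src, /*scale=*/0.05, /*kv_cache_dtype=*/"fp8");
}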
paged-attention-metal/float8.metal ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ // Helpers ------------------------------------------------------------
5
+ static inline uint as_bits(float x) { return as_type<uint>(x); }
6
+ static inline float from_bits(uint b) { return as_type<float>(b); }
7
+
8
+ // -------------------------------------------------------------------
9
+ // FP8 E4M3 (bias = 7)
10
+ // -------------------------------------------------------------------
11
+ inline float fp8_e4m3_to_float(uchar v) {
12
+ const uint s = v >> 7;
13
+ const uint exp = (v >> 3) & 0xF;
14
+ const uint man = v & 0x7;
15
+
16
+ if (exp == 0) { // zero / sub-normal
17
+ if (man == 0)
18
+ return s ? -0.f : 0.f;
19
+ const float m = float(man) / 8.f; // already scaled by 2^-3
20
+ float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6
21
+ return s ? -val : val;
22
+ }
23
+
24
+ if (exp == 0xF) { // Inf / NaN (E4M3FN keeps only NaN)
25
+ if (man != 0)
26
+ return NAN;
27
+ return s ? -INFINITY : INFINITY;
28
+ }
29
+
30
+ const float m = 1.f + float(man) / 8.f;
31
+ float val = ldexp(m, int(exp) - 7);
32
+ return s ? -val : val;
33
+ }
34
+
35
+ // -------------------------------------------------------------------
36
+ // FP8 E5M2 (bias = 15)
37
+ // -------------------------------------------------------------------
38
+ inline float fp8_e5m2_to_float(uchar v) {
39
+ const uint s = v >> 7;
40
+ const uint exp = (v >> 2) & 0x1F;
41
+ const uint man = v & 0x3;
42
+
43
+ if (exp == 0) {
44
+ if (man == 0)
45
+ return s ? -0.f : 0.f;
46
+ const float m = float(man) / 4.f;
47
+ float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14
48
+ return s ? -val : val;
49
+ }
50
+
51
+ if (exp == 0x1F) {
52
+ if (man != 0)
53
+ return NAN;
54
+ return s ? -INFINITY : INFINITY;
55
+ }
56
+
57
+ const float m = 1.f + float(man) / 4.f;
58
+ float val = ldexp(m, int(exp) - 15);
59
+ return s ? -val : val;
60
+ }
61
+
62
+ // -------------------------------------------------------------------
63
+ // Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞)
64
+ // -------------------------------------------------------------------
65
+ namespace detail {
66
+ template <int EXP_BITS, int MAN_BITS, int BIAS>
67
+ inline uchar fp32_to_fp8(float f) {
68
+ const uint bits = as_bits(f);
69
+ const uint s = bits >> 31;
70
+ const uint abs = bits & 0x7FFFFFFF;
71
+
72
+ // NaN propagates, Inf saturates
73
+ if (abs >= 0x7F800000u) {
74
+ return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) |
75
+ (abs != 0x7F800000u));
76
+ }
77
+
78
+ int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent
79
+ uint m = abs & 0x7FFFFFu; // 23-bit mantissa
80
+ const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent
81
+
82
+ // ---------- Normal path -------------------------------------------------
83
+ int e_fp8 = e + BIAS;
84
+ if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) {
85
+ // round-to-nearest-even
86
+ const int shift = 23 - MAN_BITS;
87
+ uint mant = m >> shift;
88
+ const uint lsb = mant & 1u;
89
+ const uint round = (m >> (shift - 1)) & 1u;
90
+ const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u;
91
+ mant += (round & (sticky | lsb));
92
+ if (mant >> MAN_BITS) { // mantissa overflow
93
+ mant = 0;
94
+ ++e_fp8;
95
+ if (e_fp8 > EXP_MAX)
96
+ return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞
97
+ }
98
+ return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) |
99
+ (mant & ((1u << MAN_BITS) - 1u)));
100
+ }
101
+
102
+ // ---------- Sub-normal / under-flow ------------------------------------
103
+ if (e_fp8 < 1 - MAN_BITS) // too small -> ±0
104
+ return uchar(s << 7);
105
+
106
+ // shift so that exponent becomes 1
107
+ int rshift = (1 - e_fp8) + (23 - MAN_BITS);
108
+ uint mant = (0x800000u | m); // implicit 1
109
+ uint rounded = (mant + (1u << (rshift - 1))) >> rshift;
110
+ if (rounded == 0)
111
+ return uchar(s << 7); // rounds to zero
112
+
113
+ return uchar((s << 7) | (rounded & ((1u << MAN_BITS) - 1u)));
114
+ }
115
+ } // namespace detail
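
To make the rounding path above concrete, here is one hand-worked E4M3 encoding under this scheme (an illustration, not part of the commit):

$$
3.2 = 1.6 \cdot 2^{1} \;\Rightarrow\; e_{\text{fp8}} = 1 + 7 = 8,\qquad 0.6 \cdot 8 = 4.8 \xrightarrow{\text{RNE}} 5
$$
$$
\text{byte} = 0\,1000\,101_2 = \mathtt{0x45} \;\Rightarrow\; \left(1 + \tfrac{5}{8}\right) \cdot 2^{8-7} = 3.25
$$

The round bit is set and the sticky bits are non-zero, so the mantissa rounds up from 4 to 5 regardless of the even/odd check.
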
116
+
117
+ inline uchar float_to_fp8_e4m3(float f) {
118
+ return detail::fp32_to_fp8<4, 3, 7>(f);
119
+ }
120
+ inline uchar float_to_fp8_e5m2(float f) {
121
+ return detail::fp32_to_fp8<5, 2, 15>(f);
122
+ }
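
When debugging these kernels it can help to have a host-side mirror of the decode path. Below is a minimal pure-Python sketch of the E4M3 decoder above (illustrative only; it is not part of this commit, and the function name is just for the example):

```python
# Minimal pure-Python mirror of fp8_e4m3_to_float above, useful as a
# host-side cross-check when debugging the Metal kernels.
import math


def fp8_e4m3_to_float(v: int) -> float:
    s = (v >> 7) & 0x1
    exp = (v >> 3) & 0xF
    man = v & 0x7
    sign = -1.0 if s else 1.0
    if exp == 0:                      # zero / subnormal
        return sign * math.ldexp(man / 8.0, 1 - 7)
    if exp == 0xF:                    # Inf / NaN, mirroring the Metal code
        return math.nan if man else sign * math.inf
    return sign * math.ldexp(1.0 + man / 8.0, exp - 7)


if __name__ == "__main__":
    # 0x48 -> +4.0, 0xC8 -> -4.0, 0x01 -> smallest subnormal 2**-9
    for byte in (0x48, 0xC8, 0x01):
        print(hex(byte), fp8_e4m3_to_float(byte))
```
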
paged-attention-metal/paged_attention.mm CHANGED
@@ -28,7 +28,9 @@ static std::string getModuleDirectory() {
28
 
29
  // Helper function to get kernel name based on dtype and parameters
30
  static std::string getKernelName(const std::string &base_name,
31
- torch::ScalarType dtype, int head_size,
 
 
32
  int block_size, int num_threads,
33
  int num_simd_lanes, int partition_size = 0) {
34
  std::string dtype_str;
@@ -46,8 +48,26 @@ static std::string getKernelName(const std::string &base_name,
46
  TORCH_CHECK(false, "Unsupported dtype for paged attention: ", dtype);
47
  }
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  std::string kernel_name =
50
- base_name + "_" + dtype_str + "_hs" + std::to_string(head_size) + "_bs" +
51
  std::to_string(block_size) + "_nt" + std::to_string(num_threads) +
52
  "_nsl" + std::to_string(num_simd_lanes);
53
 
@@ -106,12 +126,19 @@ void paged_attention_v1(
106
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
107
 
108
  // Validate block sparse is not supported yet
109
- // TODO: support blocksparse, k/v scale.
110
  TORCH_CHECK(
111
  !is_block_sparse,
112
  "Block sparse attention is not yet supported in Metal implementation");
113
- if (kv_cache_dtype != "auto") {
114
- TORCH_CHECK(false, "fp8 is not supported on Metal.");
 
 
 
 
 
 
 
115
  }
116
 
117
  // Validate input tensors
@@ -147,7 +174,7 @@ void paged_attention_v1(
147
 
148
  // Get kernel name - v1 kernels have partition_size=0 in their name
149
  std::string kernel_name =
150
- getKernelName("paged_attention", query.scalar_type(), head_size,
151
  block_size, num_threads, num_simd_lanes, partition_size);
152
 
153
  @autoreleasepool {
@@ -174,6 +201,7 @@ void paged_attention_v1(
174
  type:MTLDataTypeBool
175
  atIndex:10];
176
  [constants setConstantValue:&use_alibi type:MTLDataTypeBool atIndex:20];
 
177
 
178
  NSString *kernelNameStr =
179
  [NSString stringWithUTF8String:kernel_name.c_str()];
@@ -233,6 +261,18 @@ void paged_attention_v1(
233
  offset:value_cache.storage_offset() * value_cache.element_size()
234
  atIndex:buffer_idx++];
235
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  // num_kv_heads
237
  int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
238
  [enc setBytes:&num_kv_heads_i32
@@ -324,13 +364,20 @@ void paged_attention_v2(
324
  const int64_t blocksparse_head_sliding_step) {
325
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
326
 
327
- // TODO: support blocksparse, k/v scale.
328
  // Validate block sparse is not supported yet
329
  TORCH_CHECK(
330
  !is_block_sparse,
331
  "Block sparse attention is not yet supported in Metal implementation");
332
- if (kv_cache_dtype != "auto") {
333
- TORCH_CHECK(false, "fp8 is not supported on Metal.");
 
 
 
 
 
 
 
334
  }
335
 
336
  // Validate input tensors
@@ -365,7 +412,7 @@ void paged_attention_v2(
365
 
366
  // Get kernel names
367
  std::string kernel_name =
368
- getKernelName("paged_attention", query.scalar_type(), head_size,
369
  block_size, num_threads, num_simd_lanes, partition_size);
370
  // Reduce kernel doesn't have block_size in its name
371
  std::string reduce_kernel_name = "paged_attention_v2_reduce";
@@ -427,6 +474,9 @@ void paged_attention_v2(
427
  [mainConstants setConstantValue:&use_alibi
428
  type:MTLDataTypeBool
429
  atIndex:20];
 
 
 
430
 
431
  NSString *kernelNameStr =
432
  [NSString stringWithUTF8String:kernel_name.c_str()];
@@ -485,6 +535,18 @@ void paged_attention_v2(
485
  offset:value_cache.storage_offset() * value_cache.element_size()
486
  atIndex:buffer_idx++];
487
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  // num_kv_heads
489
  int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
490
  [enc setBytes:&num_kv_heads_i32
 
28
 
29
  // Helper function to get kernel name based on dtype and parameters
30
  static std::string getKernelName(const std::string &base_name,
31
+ torch::ScalarType dtype,
32
+ torch::ScalarType cache_dtype,
33
+ int head_size,
34
  int block_size, int num_threads,
35
  int num_simd_lanes, int partition_size = 0) {
36
  std::string dtype_str;
 
48
  TORCH_CHECK(false, "Unsupported dtype for paged attention: ", dtype);
49
  }
50
 
51
+ std::string cache_dtype_str;
52
+ switch (cache_dtype) {
53
+ case torch::kFloat:
54
+ cache_dtype_str = "float";
55
+ break;
56
+ case torch::kHalf:
57
+ cache_dtype_str = "half";
58
+ break;
59
+ case torch::kBFloat16:
60
+ cache_dtype_str = "bfloat16_t";
61
+ break;
62
+ case torch::kUInt8:
63
+ cache_dtype_str = "uchar";
64
+ break;
65
+ default:
66
+ TORCH_CHECK(false, "Unsupported cache dtype for paged attention: ", cache_dtype);
67
+ }
68
+
69
  std::string kernel_name =
70
+ base_name + "_" + dtype_str + "_cache_" + cache_dtype_str + "_hs" + std::to_string(head_size) + "_bs" +
71
  std::to_string(block_size) + "_nt" + std::to_string(num_threads) +
72
  "_nsl" + std::to_string(num_simd_lanes);
73
 
 
126
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
127
 
128
  // Validate block sparse is not supported yet
129
+ // TODO: support blocksparse.
130
  TORCH_CHECK(
131
  !is_block_sparse,
132
  "Block sparse attention is not yet supported in Metal implementation");
133
+
134
+ // Determine cache dtype based on kv_cache_dtype
135
+ torch::ScalarType cache_dtype = key_cache.scalar_type();
136
+ bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
137
+ if (use_fp8_scales) {
138
+ TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type");
139
+ TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars");
140
+ TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32,
141
+ "FP8 scales must be float32");
142
  }
143
 
144
  // Validate input tensors
 
174
 
175
  // Get kernel name - v1 kernels have partition_size=0 in their name
176
  std::string kernel_name =
177
+ getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size,
178
  block_size, num_threads, num_simd_lanes, partition_size);
179
 
180
  @autoreleasepool {
 
201
  type:MTLDataTypeBool
202
  atIndex:10];
203
  [constants setConstantValue:&use_alibi type:MTLDataTypeBool atIndex:20];
204
+ [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:30];
205
 
206
  NSString *kernelNameStr =
207
  [NSString stringWithUTF8String:kernel_name.c_str()];
 
261
  offset:value_cache.storage_offset() * value_cache.element_size()
262
  atIndex:buffer_idx++];
263
 
264
+ // k_scale and v_scale (for FP8)
265
+ if (use_fp8_scales) {
266
+ [enc setBuffer:getMTLBufferStorage(k_scale)
267
+ offset:k_scale.storage_offset() * k_scale.element_size()
268
+ atIndex:buffer_idx++];
269
+ [enc setBuffer:getMTLBufferStorage(v_scale)
270
+ offset:v_scale.storage_offset() * v_scale.element_size()
271
+ atIndex:buffer_idx++];
272
+ } else {
273
+ buffer_idx += 2; // Skip k_scale and v_scale buffer slots
274
+ }
275
+
276
  // num_kv_heads
277
  int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
278
  [enc setBytes:&num_kv_heads_i32
 
364
  const int64_t blocksparse_head_sliding_step) {
365
  const bool is_block_sparse = (blocksparse_vert_stride > 1);
366
 
367
+ // TODO: support blocksparse.
368
  // Validate block sparse is not supported yet
369
  TORCH_CHECK(
370
  !is_block_sparse,
371
  "Block sparse attention is not yet supported in Metal implementation");
372
+
373
+ // Determine cache dtype based on kv_cache_dtype
374
+ torch::ScalarType cache_dtype = key_cache.scalar_type();
375
+ bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
376
+ if (use_fp8_scales) {
377
+ TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type");
378
+ TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars");
379
+ TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32,
380
+ "FP8 scales must be float32");
381
  }
382
 
383
  // Validate input tensors
 
412
 
413
  // Get kernel names
414
  std::string kernel_name =
415
+ getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size,
416
  block_size, num_threads, num_simd_lanes, partition_size);
417
  // Reduce kernel doesn't have block_size in its name
418
  std::string reduce_kernel_name = "paged_attention_v2_reduce";
 
474
  [mainConstants setConstantValue:&use_alibi
475
  type:MTLDataTypeBool
476
  atIndex:20];
477
+ [mainConstants setConstantValue:&use_fp8_scales
478
+ type:MTLDataTypeBool
479
+ atIndex:30];
480
 
481
  NSString *kernelNameStr =
482
  [NSString stringWithUTF8String:kernel_name.c_str()];
 
535
  offset:value_cache.storage_offset() * value_cache.element_size()
536
  atIndex:buffer_idx++];
537
 
538
+ // k_scale and v_scale (for FP8)
539
+ if (use_fp8_scales) {
540
+ [enc setBuffer:getMTLBufferStorage(k_scale)
541
+ offset:k_scale.storage_offset() * k_scale.element_size()
542
+ atIndex:buffer_idx++];
543
+ [enc setBuffer:getMTLBufferStorage(v_scale)
544
+ offset:v_scale.storage_offset() * v_scale.element_size()
545
+ atIndex:buffer_idx++];
546
+ } else {
547
+ buffer_idx += 2; // Skip k_scale and v_scale buffer slots
548
+ }
549
+
550
  // num_kv_heads
551
  int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
552
  [enc setBytes:&num_kv_heads_i32
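
Since the specialization name now also encodes the cache dtype, a quick way to see which Metal function a given configuration resolves to is to mirror the string construction on the host. A rough Python sketch of the naming scheme follows (hypothetical helper; the partition-size handling done elsewhere in the .mm file is omitted, and the parameter values are just typical choices for illustration):

```python
# Rough mirror of getKernelName's naming scheme after this change
# (illustrative only; the real logic lives in paged_attention.mm).
def kernel_name(base, dtype, cache_dtype, head_size, block_size,
                num_threads, num_simd_lanes):
    return (f"{base}_{dtype}_cache_{cache_dtype}_hs{head_size}"
            f"_bs{block_size}_nt{num_threads}_nsl{num_simd_lanes}")

# e.g. a bfloat16 model with an FP8 (uchar) KV cache:
print(kernel_name("paged_attention", "bfloat16_t", "uchar",
                  128, 16, 256, 32))
# paged_attention_bfloat16_t_cache_uchar_hs128_bs16_nt256_nsl32
```
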
tests/kernels/test_attention.py CHANGED
@@ -34,7 +34,7 @@ HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
34
  BLOCK_SIZES = [16, 32]
35
  USE_ALIBI = [False, True]
36
  if current_platform.is_mps():
37
- KV_CACHE_DTYPE = ["auto"]
38
  else:
39
  KV_CACHE_DTYPE = ["auto", "fp8"]
40
  SEEDS = [0]
 
34
  BLOCK_SIZES = [16, 32]
35
  USE_ALIBI = [False, True]
36
  if current_platform.is_mps():
37
+ KV_CACHE_DTYPE = ["auto", "fp8"]
38
  else:
39
  KV_CACHE_DTYPE = ["auto", "fp8"]
40
  SEEDS = [0]
tests/kernels/test_cache.py CHANGED
@@ -8,7 +8,7 @@ from paged_attention.platforms import current_platform
8
 
9
  from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
 
11
- COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
12
  DTYPES = [torch.half, torch.bfloat16, torch.float]
13
  NUM_TOKENS = [42] # Arbitrary values for testing
14
  NUM_LAYERS = [1] # Arbitrary values for testing
@@ -28,7 +28,7 @@ else:
28
  DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
29
 
30
  if current_platform.is_mps():
31
- KV_CACHE_DTYPE = ["auto"]
32
  else:
33
  KV_CACHE_DTYPE = ["auto", "fp8"]
34
 
@@ -226,10 +226,10 @@ def test_reshape_and_cache(
226
 
227
  if kv_cache_dtype == "fp8":
228
  torch.testing.assert_close(
229
- result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
230
  )
231
  torch.testing.assert_close(
232
- result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
233
  )
234
  else:
235
  torch.testing.assert_close(key_cache, cloned_key_cache)
@@ -258,6 +258,9 @@ def test_reshape_and_cache_flash(
258
  device: str,
259
  kv_cache_dtype: str,
260
  ) -> None:
 
 
 
261
  current_platform.seed_everything(seed)
262
  torch.set_default_device(device)
263
 
@@ -346,10 +349,10 @@ def test_reshape_and_cache_flash(
346
 
347
  if kv_cache_dtype == "fp8":
348
  torch.testing.assert_close(
349
- result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
350
  )
351
  torch.testing.assert_close(
352
- result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
353
  )
354
  else:
355
  torch.testing.assert_close(key_cache, cloned_key_cache)
@@ -387,8 +390,8 @@ def test_swap_blocks(
387
 
388
  current_platform.seed_everything(seed)
389
 
390
- src_device = device if direction[0] == "cuda" else "cpu"
391
- dst_device = device if direction[1] == "cuda" else "cpu"
392
 
393
  src_blocks = random.sample(range(num_blocks), num_mappings)
394
  # For the same device, mapping must not overlap
@@ -474,8 +477,6 @@ def test_fp8_e4m3_conversion(
474
  seed: int,
475
  device: str,
476
  ) -> None:
477
- if current_platform.is_mps():
478
- pytest.skip()
479
  current_platform.seed_everything(seed)
480
 
481
  low = -224.0
@@ -490,4 +491,60 @@ def test_fp8_e4m3_conversion(
490
  converted_cache = torch.empty_like(cache)
491
  ops.convert_fp8(converted_cache, cache_fp8)
492
 
493
- torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
8
 
9
  from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
 
11
+ COPYING_DIRECTION = [("gpu", "cpu"), ("gpu", "gpu"), ("cpu", "gpu")]
12
  DTYPES = [torch.half, torch.bfloat16, torch.float]
13
  NUM_TOKENS = [42] # Arbitrary values for testing
14
  NUM_LAYERS = [1] # Arbitrary values for testing
 
28
  DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
29
 
30
  if current_platform.is_mps():
31
+ KV_CACHE_DTYPE = ["auto", "fp8"]
32
  else:
33
  KV_CACHE_DTYPE = ["auto", "fp8"]
34
 
 
226
 
227
  if kv_cache_dtype == "fp8":
228
  torch.testing.assert_close(
229
+ result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
230
  )
231
  torch.testing.assert_close(
232
+ result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
233
  )
234
  else:
235
  torch.testing.assert_close(key_cache, cloned_key_cache)
 
258
  device: str,
259
  kv_cache_dtype: str,
260
  ) -> None:
261
+ # Flash variant doesn't support FP8 on MPS devices yet
262
+ if current_platform.is_mps() and kv_cache_dtype == "fp8":
263
+ pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
264
  current_platform.seed_everything(seed)
265
  torch.set_default_device(device)
266
 
 
349
 
350
  if kv_cache_dtype == "fp8":
351
  torch.testing.assert_close(
352
+ result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
353
  )
354
  torch.testing.assert_close(
355
+ result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
356
  )
357
  else:
358
  torch.testing.assert_close(key_cache, cloned_key_cache)
 
390
 
391
  current_platform.seed_everything(seed)
392
 
393
+ src_device = device if direction[0] == "gpu" else "cpu"
394
+ dst_device = device if direction[1] == "gpu" else "cpu"
395
 
396
  src_blocks = random.sample(range(num_blocks), num_mappings)
397
  # For the same device, mapping must not overlap
 
477
  seed: int,
478
  device: str,
479
  ) -> None:
 
 
480
  current_platform.seed_everything(seed)
481
 
482
  low = -224.0
 
491
  converted_cache = torch.empty_like(cache)
492
  ops.convert_fp8(converted_cache, cache_fp8)
493
 
494
+ torch.testing.assert_close(cache, converted_cache, atol=0.02, rtol=0.2)
495
+
496
+
497
+ @pytest.mark.parametrize("src_dtype", [torch.float, torch.half, torch.bfloat16, torch.uint8])
498
+ @pytest.mark.parametrize("dst_dtype", [torch.float, torch.half, torch.bfloat16, torch.uint8])
499
+ @pytest.mark.parametrize("scale", [1.0, 0.5, 2.0, 0.1])
500
+ @pytest.mark.parametrize("device", DEVICES)
501
+ @torch.inference_mode()
502
+ def test_convert_fp8_comprehensive(
503
+ src_dtype: torch.dtype,
504
+ dst_dtype: torch.dtype,
505
+ scale: float,
506
+ device: str,
507
+ ) -> None:
508
+ """Test comprehensive FP8 conversion between all supported types"""
509
+ if current_platform.is_mps() and device != "mps:0":
510
+ pytest.skip()
511
+ if not current_platform.is_mps() and device == "mps:0":
512
+ pytest.skip()
513
+
514
+ current_platform.seed_everything(0)
515
+ torch.set_default_device(device)
516
+
517
+ # Create test tensor with reasonable values for FP8 range
518
+ shape = (32, 8, 16, 16) # Small tensor for fast testing
519
+ if src_dtype == torch.uint8:
520
+ # Create FP8 data by converting from float
521
+ src_float = torch.randn(shape, dtype=torch.float, device=device) * 0.1
522
+ src_cache = torch.empty(shape, dtype=torch.uint8, device=device)
523
+ ops.convert_fp8(src_cache, src_float, 1.0, "fp8")
524
+ else:
525
+ # Create source data in range suitable for FP8 conversion
526
+ src_cache = torch.randn(shape, dtype=src_dtype, device=device) * 0.1
527
+
528
+ # Perform conversion
529
+ dst_cache = torch.empty_like(src_cache, dtype=dst_dtype, device=device)
530
+ ops.convert_fp8(dst_cache, src_cache, scale, "fp8")
531
+
532
+ # Verify the tensor was modified (not all zeros)
533
+ assert not torch.allclose(dst_cache.float(), torch.zeros_like(dst_cache.float()))
534
+
535
+ # For round-trip tests (same type), verify approximate equality
536
+ if src_dtype == dst_dtype and scale == 1.0:
537
+ if src_dtype == torch.uint8:
538
+ # FP8 -> FP8 should be identity with scale=1.0
539
+ torch.testing.assert_close(src_cache, dst_cache)
540
+ else:
541
+ # Non-FP8 -> Non-FP8 should be identity with scale=1.0
542
+ torch.testing.assert_close(src_cache, dst_cache, atol=1e-6, rtol=1e-5)
543
+
544
+ # For FP8 round-trip tests (float -> FP8 -> float), verify reasonable approximation
545
+ if src_dtype != torch.uint8 and dst_dtype == torch.uint8 and scale == 1.0:
546
+ # Convert back to verify round-trip accuracy
547
+ roundtrip = torch.empty_like(src_cache, dtype=src_dtype, device=device)
548
+ ops.convert_fp8(roundtrip, dst_cache, 1.0, "fp8")
549
+ # FP8 has limited precision, so use relaxed tolerances
550
+ torch.testing.assert_close(src_cache, roundtrip, atol=0.02, rtol=0.2)
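
Condensed, the round trip the new test exercises looks like the sketch below (assuming the same ops module the tests import; the device, shapes, and import path are illustrative):

```python
import torch
from paged_attention import ops  # assumed import path, as in the tests

x = torch.randn(32, 8, 16, 16, dtype=torch.float, device="mps") * 0.1

# float -> FP8 (stored as uint8) -> float
x_fp8 = torch.empty_like(x, dtype=torch.uint8)
ops.convert_fp8(x_fp8, x, 1.0, "fp8")

x_back = torch.empty_like(x)
ops.convert_fp8(x_back, x_fp8, 1.0, "fp8")

# E4M3 keeps only ~2 significant digits, hence the relaxed tolerances
torch.testing.assert_close(x, x_back, atol=0.02, rtol=0.2)
```
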