Eric Buehler committed
Commit · 05b1349
Parent(s): 6d6d594
Add metal kernels
Browse files
- README.md +12 -3
- build.toml +19 -0
- flake.lock +117 -0
- flake.nix +17 -0
- paged-attention-metal/attention/pagedattention.metal +1187 -0
- paged-attention-metal/cache/copy_blocks.metal +50 -0
- paged-attention-metal/cache/reshape_and_cache.metal +74 -0
- paged-attention-metal/paged_attention.mm +117 -0
- paged-attention-metal/utils.metal +0 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/allclose_default.py +14 -0
- tests/kernels/conftest.py +158 -0
- tests/kernels/test_attention.py +418 -0
- tests/kernels/test_cache.py +486 -0
- tests/kernels/utils.py +92 -0
README.md CHANGED
@@ -1,3 +1,12 @@
----
-license: apache-2.0
----
+---
+license: apache-2.0
+tags:
+- kernel
+---
+
+
+
+
+## attention
+
+Paged attention kernels from [vLLM](https://github.com/vllm-project/).
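
For context, kernel repos tagged like this are typically consumed from Python through the Hugging Face `kernels` loader. The snippet below is a hypothetical sketch, not part of this commit: the repo id is a placeholder, and the ops actually exposed depend on `torch-ext/torch_binding.cpp`, which this diff only references.

```python
# Hypothetical sketch: loading a Hub kernel repo with the `kernels` package.
# The repo id is an assumed placeholder for illustration only.
from kernels import get_kernel

paged_attention = get_kernel("kernels-community/paged-attention")  # assumed id
# Inspect whatever ops the torch binding registered.
print([name for name in dir(paged_attention) if not name.startswith("_")])
```
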
build.toml ADDED
@@ -0,0 +1,19 @@
+[general]
+name = "paged_attention"
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+
+[kernel.activation_metal]
+backend = "metal"
+src = [
+  "paged-attention-metal/attention/paged_attention.metal",
+  "paged-attention-metal/cache/copy_blocks.metal",
+  "paged-attention-metal/cache/reshape_and_cache.metal",
+  "paged-attention-metal/utils.metal",
+  "paged-attention-metal/paged_attention.mm",
+]
+depends = [ "torch" ]
flake.lock ADDED
@@ -0,0 +1,117 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "nixpkgs": "nixpkgs",
+        "rocm-nix": "rocm-nix"
+      },
+      "locked": {
+        "lastModified": 1744976941,
+        "narHash": "sha256-+csrhVaT6Mj2j1FM7P2BDITvf1Xwj2AKdMm0IKZK340=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "0a278c2e9aaf6003a4ec6fe35c7158624762de5a",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1743559129,
+        "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nixos",
+        "ref": "nixos-unstable-small",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "rocm-nix": {
+      "inputs": {
+        "nixpkgs": [
+          "kernel-builder",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1743085847,
+        "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
+        "owner": "huggingface",
+        "repo": "rocm-nix",
+        "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "rocm-nix",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
flake.nix ADDED
@@ -0,0 +1,17 @@
+{
+  description = "Flake for attention kernels";
+
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
+  };
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
+    kernel-builder.lib.genFlakeOutputs {
+      path = ./.;
+      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+    };
+}
paged-attention-metal/attention/pagedattention.metal ADDED
@@ -0,0 +1,1187 @@
+// Updated from MLX commit hash f70764a
+
+#include "utils.metal"
+#include <metal_simdgroup>
+#include <metal_stdlib>
+
+using namespace metal;
+
+// ========================================== Generic vector types
+
+// A vector type to store Q, K, V elements.
+template <typename T, int VEC_SIZE> struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template <typename T> struct FloatVec {};
+
+// Template vector operations.
+template <typename Acc, typename A, typename B> inline Acc mul(A a, B b);
+
+template <typename T> inline float sum(T v);
+
+template <typename T> inline float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+template <typename A, typename T> inline float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+// FP32 vector data types.
+struct Float8_ {
+  float4 x;
+  float4 y;
+};
+
+template <> struct Vec<float, 1> {
+  using Type = float;
+};
+template <> struct Vec<float, 2> {
+  using Type = float2;
+};
+template <> struct Vec<float, 4> {
+  using Type = float4;
+};
+template <> struct Vec<float, 8> {
+  using Type = Float8_;
+};
+
+template <> struct FloatVec<float> {
+  using Type = float;
+};
+template <> struct FloatVec<float2> {
+  using Type = float2;
+};
+template <> struct FloatVec<float4> {
+  using Type = float4;
+};
+template <> struct FloatVec<Float8_> {
+  using Type = Float8_;
+};
+
+template <> inline float mul(float a, float b) { return a * b; }
+
+template <> inline float2 mul(float2 a, float2 b) { return a * b; }
+
+template <> inline float4 mul(float4 a, float4 b) { return a * b; }
+
+template <> inline Float8_ mul(Float8_ a, Float8_ b) {
+  Float8_ c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  return c;
+}
+
+template <> inline float sum(float a) { return a; }
+
+template <> inline float sum(float2 a) { return a.x + a.y; }
+
+template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; }
+
+template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); }
+
+inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) {
+  Float8_ res;
+  res.x = fma(a.x, b.x, c.x);
+  res.y = fma(a.y, b.y, c.y);
+  return res;
+}
+
+inline void from_float(thread float &dst, float src) { dst = src; }
+inline void from_float(thread float2 &dst, float2 src) { dst = src; }
+inline void from_float(thread float4 &dst, float4 src) { dst = src; }
+inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; }
+
+// BF16 vector data types.
+// #if defined(__HAVE_BFLOAT__)
+
+// struct Bfloat8_ {
+//   bfloat4 x;
+//   bfloat4 y;
+// };
+
+// template<>
+// struct Vec<bfloat, 1> {
+//   using Type = bfloat;
+// };
+// template<>
+// struct Vec<bfloat, 2> {
+//   using Type = bfloat2;
+// };
+// template<>
+// struct Vec<bfloat, 4> {
+//   using Type = bfloat4;
+// };
+// template<>
+// struct Vec<bfloat, 8> {
+//   using Type = Bfloat8_;
+// };
+
+// template<>
+// struct FloatVec<bfloat> {
+//   using Type = float;
+// };
+// template<>
+// struct FloatVec<bfloat2> {
+//   using Type = float2;
+// };
+// template<>
+// struct FloatVec<bfloat4> {
+//   using Type = float4;
+// };
+// template<>
+// struct FloatVec<Bfloat8_> {
+//   using Type = Float8_;
+// };
+
+// template<>
+// inline float mul(bfloat a, bfloat b) {
+//   return (float)a * (float)b;
+// }
+// template<>
+// inline bfloat mul(bfloat a, bfloat b) {
+//   return a*b;
+// }
+
+// template<>
+// inline float2 mul(bfloat2 a, bfloat2 b) {
+//   return (float2)a * (float2)b;
+// }
+// template<>
+// inline bfloat2 mul(bfloat2 a, bfloat2 b) {
+//   return a * b;
+// }
+
+// template<>
+// inline float4 mul(bfloat4 a, bfloat4 b) {
+//   return (float4)a * (float4)b;
+// }
+// template<>
+// inline bfloat4 mul(bfloat4 a, bfloat4 b) {
+//   return a * b;
+// }
+
+// template<>
+// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
+//   Float8_ c;
+//   c.x = mul<float4, bfloat4, bfloat4>(a.x, b.x);
+//   c.y = mul<float4, bfloat4, bfloat4>(a.y, b.y);
+//   return c;
+// }
+// template<>
+// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
+//   Bfloat8_ c;
+//   c.x = mul<bfloat4, bfloat4, bfloat4>(a.x, b.x);
+//   c.y = mul<bfloat4, bfloat4, bfloat4>(a.y, b.y);
+//   return c;
+// }
+
+// template<>
+// inline float sum(bfloat a) {
+//   return (float)a;
+// }
+
+// template<>
+// inline float sum(bfloat2 a) {
+//   return (float)a.x + (float)a.y;
+// }
+
+// template<>
+// inline float sum(bfloat4 a) {
+//   return sum(a.x) + sum(a.y);
+// }
+
+// template<>
+// inline float sum(Bfloat8_ a) {
+//   return sum(a.x) + sum(a.y);
+// }
+
+// inline float fma(bfloat a, bfloat b, float c) {
+//   return (float)a * (float)b + c;
+// }
+
+// inline float2 fma(bfloat2 a, bfloat2 b, float2 c) {
+//   return (float2)a * (float2)b + c;
+// }
+
+// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) {
+//   return (float4)a * (float4)b + c;
+// }
+
+// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
+//   Float8_ res;
+//   res.x = fma((float4)a.x, (float4)b.x, (float4)c.x);
+//   res.y = fma((float4)a.y, (float4)b.y, (float4)c.y);
+//   return res;
+// }
+// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
+//   Bfloat8_ res;
+//   res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x);
+//   res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y);
+//   return c;
+// }
+
+// inline void from_float(thread bfloat& dst, float src) {
+//   dst = static_cast<bfloat>(src);
+// }
+// inline void from_float(thread bfloat2& dst, float2 src) {
+//   dst.x = static_cast<bfloat>(src.x);
+//   dst.y = static_cast<bfloat>(src.y);
+// }
+// inline void from_float(thread bfloat4& dst, float4 src) {
+//   dst.x = static_cast<bfloat>(src.x);
+//   dst.y = static_cast<bfloat>(src.y);
+//   dst.z = static_cast<bfloat>(src.z);
+//   dst.w = static_cast<bfloat>(src.w);
+// }
+// inline void from_float(thread Bfloat8_& dst, Float8_ src) {
+//   bfloat4 x;
+//   bfloat4 y;
+//   from_float(x, src.x);
+//   from_float(y, src.y);
+//   dst.x = x;
+//   dst.y = y;
+// }
+
+// #else
+
+struct Bfloat2_ {
+  bfloat16_t x;
+  bfloat16_t y;
+};
+
+struct Bfloat4_ {
+  Bfloat2_ x;
+  Bfloat2_ y;
+};
+
+struct Bfloat8_ {
+  Bfloat4_ x;
+  Bfloat4_ y;
+};
+
+template <> struct Vec<bfloat16_t, 1> {
+  using Type = bfloat16_t;
+};
+template <> struct Vec<bfloat16_t, 2> {
+  using Type = Bfloat2_;
+};
+template <> struct Vec<bfloat16_t, 4> {
+  using Type = Bfloat4_;
+};
+template <> struct Vec<bfloat16_t, 8> {
+  using Type = Bfloat8_;
+};
+
+template <> struct FloatVec<bfloat16_t> {
+  using Type = float;
+};
+template <> struct FloatVec<Bfloat2_> {
+  using Type = float2;
+};
+template <> struct FloatVec<Bfloat4_> {
+  using Type = float4;
+};
+template <> struct FloatVec<Bfloat8_> {
+  using Type = Float8_;
+};
+
+template <> inline float mul(bfloat16_t a, bfloat16_t b) {
+  return (float)a * (float)b;
+}
+template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; }
+
+template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) {
+  float2 a_f((float)a.x, (float)a.y);
+  float2 b_f((float)b.x, (float)b.y);
+  return a_f * b_f;
+}
+template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) {
+  Bfloat2_ c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  return c;
+}
+
+template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) {
+  float2 x = mul<float2, Bfloat2_, Bfloat2_>(a.x, b.x);
+  float2 y = mul<float2, Bfloat2_, Bfloat2_>(a.y, b.y);
+  float4 c;
+  c.x = x.x;
+  c.y = x.y;
+  c.z = y.x;
+  c.w = y.y;
+  return c;
+}
+template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) {
+  Bfloat4_ c;
+  c.x = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.x, b.x);
+  c.y = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.y, b.y);
+  return c;
+}
+
+template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
+  Float8_ c;
+  c.x = mul<float4, Bfloat4_, Bfloat4_>(a.x, b.x);
+  c.y = mul<float4, Bfloat4_, Bfloat4_>(a.y, b.y);
+  return c;
+}
+template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
+  Bfloat8_ c;
+  c.x = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.x, b.x);
+  c.y = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.y, b.y);
+  return c;
+}
+
+template <> inline float sum(bfloat16_t a) { return (float)a; }
+
+template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; }
+
+template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); }
+
+template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); }
+
+inline float fma(bfloat16_t a, bfloat16_t b, float c) {
+  return (float)a * (float)b + c;
+}
+inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) {
+  return a * b + c;
+}
+
+inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) {
+  float2 a_f((float)a.x, (float)a.y);
+  float2 b_f((float)b.x, (float)b.y);
+  return a_f * b_f + c;
+}
+inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) {
+  Bfloat2_ res;
+  res.x = a.x * b.x + c.x;
+  res.y = a.y * b.y + c.y;
+  return res;
+}
+
+inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) {
+  float4 res;
+  res.x = fma(a.x.x, b.x.x, c.x);
+  res.y = fma(a.x.y, b.x.y, c.y);
+  res.z = fma(a.y.x, b.y.x, c.z);
+  res.w = fma(a.y.y, b.y.y, c.w);
+  return res;
+}
+inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) {
+  Bfloat4_ res;
+  res.x = fma(a.x, b.x, c.x);
+  res.y = fma(a.y, b.y, c.y);
+  return res;
+}
+
+inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
+  float4 x = fma(a.x, b.x, c.x);
+  float4 y = fma(a.y, b.y, c.y);
+  Float8_ res;
+  res.x = x;
+  res.y = y;
+  return res;
+}
+inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
+  Bfloat8_ res;
+  res.x = fma(a.x, b.x, c.x);
+  res.y = fma(a.y, b.y, c.y);
+  return res;
+}
+
+inline void from_float(thread bfloat16_t &dst, float src) {
+  dst = static_cast<bfloat16_t>(src);
+}
+inline void from_float(thread Bfloat2_ &dst, float2 src) {
+  dst.x = static_cast<bfloat16_t>(src.x);
+  dst.y = static_cast<bfloat16_t>(src.y);
+}
+inline void from_float(thread Bfloat4_ &dst, float4 src) {
+  dst.x.x = static_cast<bfloat16_t>(src.x);
+  dst.x.y = static_cast<bfloat16_t>(src.y);
+  dst.y.x = static_cast<bfloat16_t>(src.z);
+  dst.y.y = static_cast<bfloat16_t>(src.w);
+}
+inline void from_float(thread Bfloat8_ &dst, Float8_ src) {
+  Bfloat4_ x;
+  Bfloat4_ y;
+  from_float(x, src.x);
+  from_float(y, src.y);
+  dst.x = x;
+  dst.y = y;
+}
+
+// #endif
+
+// FP16 vector data types.
+struct Half8_ {
+  half4 x;
+  half4 y;
+};
+
+template <> struct Vec<half, 1> {
+  using Type = half;
+};
+template <> struct Vec<half, 2> {
+  using Type = half2;
+};
+template <> struct Vec<half, 4> {
+  using Type = half4;
+};
+template <> struct Vec<half, 8> {
+  using Type = Half8_;
+};
+
+template <> struct FloatVec<half> {
+  using Type = float;
+};
+template <> struct FloatVec<half2> {
+  using Type = float2;
+};
+template <> struct FloatVec<half4> {
+  using Type = float4;
+};
+template <> struct FloatVec<Half8_> {
+  using Type = Float8_;
+};
+
+template <> inline float mul(half a, half b) { return (float)a * (float)b; }
+template <> inline half mul(half a, half b) { return a * b; }
+
+template <> inline float2 mul(half2 a, half2 b) {
+  return (float2)a * (float2)b;
+}
+template <> inline half2 mul(half2 a, half2 b) { return a * b; }
+
+template <> inline float4 mul(half4 a, half4 b) {
+  return (float4)a * (float4)b;
+}
+template <> inline half4 mul(half4 a, half4 b) { return a * b; }
+
+template <> inline Float8_ mul(Half8_ a, Half8_ b) {
+  float4 x = mul<float4, half4, half4>(a.x, b.x);
+  float4 y = mul<float4, half4, half4>(a.y, b.y);
+  Float8_ c;
+  c.x = x;
+  c.y = y;
+  return c;
+}
+template <> inline Half8_ mul(Half8_ a, Half8_ b) {
+  Half8_ c;
+  c.x = mul<half4, half4, half4>(a.x, b.x);
+  c.y = mul<half4, half4, half4>(a.y, b.y);
+  return c;
+}
+
+template <> inline float sum(half a) { return (float)a; }
+
+template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; }
+
+template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; }
+
+template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); }
+
+inline float fma(half a, half b, float c) { return (float)a * (float)b + c; }
+
+inline float2 fma(half2 a, half2 b, float2 c) {
+  return (float2)a * (float2)b + c;
+}
+
+inline float4 fma(half4 a, half4 b, float4 c) {
+  return (float4)a * (float4)b + c;
+}
+
+inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) {
+  float4 x = fma(a.x, b.x, c.x);
+  float4 y = fma(a.y, b.y, c.y);
+  Float8_ res;
+  res.x = x;
+  res.y = y;
+  return res;
+}
+inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) {
+  Half8_ res;
+  res.x = fma(a.x, b.x, c.x);
+  res.y = fma(a.y, b.y, c.y);
+  return res;
+}
+
+inline void from_float(thread half &dst, float src) {
+  dst = static_cast<half>(src);
+}
+inline void from_float(thread half2 &dst, float2 src) {
+  dst.x = static_cast<half>(src.x);
+  dst.y = static_cast<half>(src.y);
+}
+inline void from_float(thread half4 &dst, float4 src) {
+  dst.x = static_cast<half>(src.x);
+  dst.y = static_cast<half>(src.y);
+  dst.z = static_cast<half>(src.z);
+  dst.w = static_cast<half>(src.w);
+}
+inline void from_float(thread Half8_ &dst, Float8_ src) {
+  half4 x;
+  half4 y;
+  from_float(x, src.x);
+  from_float(y, src.y);
+  dst.x = x;
+  dst.y = y;
+}
+
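
The `FloatVec` specializations above exist so that half and bfloat operands can be multiplied while the running sums stay in fp32. The NumPy sketch below models that idea only; it is not part of the kernel, and the function name is illustrative.

```python
# Sketch of the FloatVec idea: half-precision operands, fp32 accumulation,
# mirroring mul<float4, half4, half4> plus the fma(...) overloads above.
import numpy as np

def dot_fp32_accum(q, k):
    acc = np.float32(0.0)
    for qi, ki in zip(q, k):
        acc = np.float32(qi) * np.float32(ki) + acc  # fma-style update
    return float(acc)

q = np.random.randn(128).astype(np.float16)
k = np.random.randn(128).astype(np.float16)
# Close to the all-fp32 reference dot product:
print(dot_fp32_accum(q, k), float(q.astype(np.float32) @ k.astype(np.float32)))
```
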
+// ========================================== Dot product utilities
+
+// TODO(EricLBuehler): optimize with vectorization
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
+inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
+  // Compute the parallel products for Q*K^T (treat vector lanes separately).
+  using A_vec = typename FloatVec<Vec>::Type;
+  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+  for (int ii = 1; ii < N; ++ii) {
+    qk_vec = fma(q[ii], k[ii], qk_vec);
+  }
+
+  // Finalize the reduction across lanes.
+  float qk = sum(qk_vec);
+#pragma unroll
+  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+    qk += simd_shuffle_xor(qk, mask);
+  }
+  return qk;
+}
+
+template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
+  template <typename Vec, int N>
+  static inline float dot(const threadgroup Vec (&q)[N],
+                          const thread Vec (&k)[N]) {
+    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+  }
+};
+
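
`qk_dot_` finishes with a `simd_shuffle_xor` butterfly: after log2(THREAD_GROUP_SIZE) rounds, every lane in the group holds the complete partial dot product. A small serial Python model of that reduction pattern:

```python
# Model of the simd_shuffle_xor butterfly in qk_dot_: each round, lane i
# adds the value held by lane i ^ mask; all lanes converge to the group sum.
def butterfly_sum(lanes):
    vals = list(lanes)            # one value per lane; length is a power of 2
    mask = len(vals) // 2
    while mask >= 1:
        vals = [vals[i] + vals[i ^ mask] for i in range(len(vals))]
        mask //= 2
    return vals

print(butterfly_sum([1.0, 2.0, 3.0, 4.0]))  # -> [10.0, 10.0, 10.0, 10.0]
```
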
+// ========================================== Block sum utility
+
+// Utility function for attention softmax.
+template <int NUM_WARPS, int NUM_SIMD_LANES>
+inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
+                       uint simd_lid) {
+  // Compute the sum per simdgroup.
+#pragma unroll
+  for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
+    sum += simd_shuffle_xor(sum, mask);
+  }
+
+  // Simd leaders store the data to shared memory.
+  if (simd_lid == 0) {
+    red_smem[simd_tid] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // The warps compute the final sums.
+  if (simd_lid < NUM_WARPS) {
+    sum = red_smem[simd_lid];
+  }
+
+  // Parallel reduction inside the simd group.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += simd_shuffle_xor(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return simd_shuffle(sum, 0);
+}
+
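
`block_sum` is a two-level reduction: an intra-simdgroup butterfly, a spill of one partial per simdgroup into `red_smem`, then a second butterfly over those partials followed by a broadcast. The Python sketch below models the data flow serially, so the barriers are implicit:

```python
# Serial model of block_sum: reduce within each simdgroup, store one partial
# per "warp" in shared memory, then reduce the partials and broadcast.
def block_sum(values, num_simd_lanes):
    warps = [values[i:i + num_simd_lanes]
             for i in range(0, len(values), num_simd_lanes)]
    red_smem = [sum(w) for w in warps]  # simd leaders (simd_lid == 0) store
    return sum(red_smem)                # second reduction, then broadcast

print(block_sum(list(range(64)), num_simd_lanes=32) == sum(range(64)))  # True
```
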
+// ========================================== Paged Attention kernel
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+constant bool use_partitioning [[function_constant(10)]];
+constant bool use_alibi [[function_constant(20)]];
+
+template <typename T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
+          int NUM_SIMD_LANES, int PARTITION_SIZE = 0>
+[[kernel]] void paged_attention(
+    device float *exp_sums
+        [[buffer(0), function_constant(use_partitioning)]], // [num_seqs, num_heads,
+                                                            // max_num_partitions]
+    device float *max_logits
+        [[buffer(1), function_constant(use_partitioning)]], // [num_seqs, num_heads,
+                                                            // max_num_partitions]
+    device T *out
+        [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size]
+    device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size]
+    device const T *k_cache
+        [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+    device const T *v_cache
+        [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size]
+    const constant int &num_kv_heads [[buffer(6)]], // [num_heads]
+    const constant float &scale [[buffer(7)]],
+    const constant float &softcapping [[buffer(8)]],
+    device const uint32_t *block_tables
+        [[buffer(9)]], // [num_seqs, max_num_blocks_per_seq]
+    device const uint32_t *context_lens [[buffer(10)]], // [num_seqs]
+    const constant int &max_num_blocks_per_seq [[buffer(11)]],
+    device const float *alibi_slopes
+        [[buffer(12), function_constant(use_alibi)]], // [num_heads]
+    const constant int &q_stride [[buffer(13)]],
+    const constant int &kv_block_stride [[buffer(14)]],
+    const constant int &kv_head_stride [[buffer(15)]],
+    threadgroup char *shared_mem [[threadgroup(0)]],
+    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
+    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
+    uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
+    uint simd_tid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  const int seq_idx = threadgroup_position_in_grid.y;
+  const int partition_idx = threadgroup_position_in_grid.z;
+  const int max_num_partitions = threadgroups_per_grid.z;
+  const int thread_idx = thread_position_in_threadgroup.x;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const uint32_t context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
+                                       // divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES);
+  constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
+  const int warp_idx = simd_tid;
+  const int lane = simd_lid;
+
+  const int head_idx = threadgroup_position_in_grid.x;
+  const int num_heads = threadgroups_per_grid.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1);
+  using K_vec = typename Vec<T, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<T, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the thread group size is 4, then the first thread in the
+  // group has 0, 4, 8, ... th vectors of the query, and the second thread has
+  // 1, 5, 9, ... th vectors of the query, and so on.
+  const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const device Q_vec *>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Use fp32 on softmax logits for better accuracy
+  threadgroup float *logits = reinterpret_cast<threadgroup float *>(shared_mem);
+  // Workspace for reduction
+  threadgroup float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(T);
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const device uint32_t *block_table =
+      block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the thread group size is 4, then the first thread in the
+    // group has 0, 4, 8, ... th vectors of the key, and the second thread has
+    // 1, 5, 9, ... th vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset =
+          (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const device T *k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+        k_vecs[j] = *reinterpret_cast<const device K_vec *>(
+            k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
+
+      // Apply softcapping
+      if (softcapping != 1.0) {
+        qk = precise::tanh(qk / softcapping) * softcapping;
+      }
+
+      // Add the ALiBi bias if slopes are given.
+      qk +=
+          (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE: It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : max(qk_max, qk);
+      }
+    }
+  }
+
+// Perform reduction across the threads in the same warp to get the
+// max qk value for each "warp" (not across the thread block yet).
+// The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = simd_shuffle(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = exp(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(&red_smem[NUM_WARPS], exp_sum,
+                                                 simd_tid, simd_lid);
+
+  // Compute softmax.
+  const float inv_sum = divide(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
+    device float *max_logits_ptr =
+        max_logits + seq_idx * num_heads * max_num_partitions +
+        head_idx * max_num_partitions + partition_idx;
+    *max_logits_ptr = qk_max;
+    device float *exp_sums_ptr = exp_sums +
+                                 seq_idx * num_heads * max_num_partitions +
+                                 head_idx * max_num_partitions + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE);
+  using V_vec = typename Vec<T, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<T, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+
+  // NOTE: We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  T zero_value = 0;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    Float_L_vec logits_float_vec = *reinterpret_cast<threadgroup Float_L_vec *>(
+        logits + token_idx - start_token_idx);
+    from_float(logits_vec, logits_float_vec);
+
+    const device T *v_ptr = v_cache + physical_block_number * kv_block_stride +
+                            kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        // NOTE: When v_vec contains the tokens that are out of the context,
+        // we should explicitly zero out the values since they may contain NaNs.
+        // See
+        // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+        V_vec v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
+        if (block_idx == num_context_blocks - 1) {
+          thread T *v_vec_ptr = reinterpret_cast<thread T *>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] =
+                token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+// Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += simd_shuffle_xor(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE: A barrier is required because the shared memory space for logits
+  // is reused for the output.
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Perform reduction across warps.
+  threadgroup float *out_smem =
+      reinterpret_cast<threadgroup float *>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    device T *out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        *(out_ptr + row_idx) = T(accs[i]);
+      }
+    }
+  }
+}
+
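
Per (sequence, head), the kernel above computes logits through the block table, a numerically stable softmax, then a probability-weighted sum over V. The NumPy sketch below mirrors that math with a simplified flat [num_blocks, block_size, head_size] cache view (an assumption made to keep it short; the real K cache is blocked as [..., head_size/x, block_size, x], and ALiBi is omitted):

```python
# Reference semantics of the v1 kernel for one (seq, head), under the
# simplified flat cache layout described above.
import numpy as np

def paged_attention_ref(q, k_blocks, v_blocks, block_table, context_len,
                        scale, softcapping=1.0):
    block_size = k_blocks.shape[1]
    logits = np.empty(context_len, dtype=np.float32)
    for tok in range(context_len):
        phys = block_table[tok // block_size]   # logical -> physical block
        off = tok % block_size
        qk = scale * float(q @ k_blocks[phys, off])
        if softcapping != 1.0:
            qk = np.tanh(qk / softcapping) * softcapping
        logits[tok] = qk
    probs = np.exp(logits - logits.max())       # stable softmax
    probs /= probs.sum() + 1e-6                 # matches the kernel's epsilon
    out = np.zeros_like(q, dtype=np.float32)
    for tok in range(context_len):
        phys = block_table[tok // block_size]
        out += probs[tok] * v_blocks[phys, tok % block_size]
    return out

q = np.random.randn(64).astype(np.float32)
kb = np.random.randn(8, 16, 64).astype(np.float32)
vb = np.random.randn(8, 16, 64).astype(np.float32)
print(paged_attention_ref(q, kb, vb, block_table=[3, 0, 5], context_len=40,
                          scale=64 ** -0.5).shape)  # (64,)
```
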
+template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
+          int PARTITION_SIZE = 0>
+[[kernel]] void paged_attention_v2_reduce(
+    device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]],
+    const device float *max_logits [[buffer(2)]],
+    const device T *tmp_out [[buffer(3)]],
+    device uint32_t *context_lens [[buffer(4)]],
+    const constant int &max_num_partitions [[buffer(5)]],
+    threadgroup char *shared_mem [[threadgroup(0)]],
+    uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
+    uint3 threadgroups_per_grid [[threadgroups_per_grid]],
+    uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
+    uint3 threads_per_threadgroup [[threads_per_threadgroup]],
+    uint simd_tid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  const int num_heads = threadgroups_per_grid.x;
+  const int head_idx = threadgroup_position_in_grid.x;
+  const int seq_idx = threadgroup_position_in_grid.y;
+  const uint32_t context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    device T *out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const device T *tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
+         i += threads_per_threadgroup.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
+  const int warp_idx = simd_tid;
+  const int lane = simd_lid;
+
+  // Workspace for reduction.
+  threadgroup float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  threadgroup float *shared_max_logits =
+      reinterpret_cast<threadgroup float *>(shared_mem);
+  const device float *max_logits_ptr =
+      max_logits + seq_idx * num_heads * max_num_partitions +
+      head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
+       i += threads_per_threadgroup.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = max(max_logit, l);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+// Get the global max logit.
+// Reduce within the warp.
+#pragma unroll
+  for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
+    max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = simd_shuffle(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
+      shared_mem + sizeof(float) * num_partitions);
+  const device float *exp_sums_ptr = exp_sums +
+                                     seq_idx * num_heads * max_num_partitions +
+                                     head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = thread_position_in_threadgroup.x; i < num_partitions;
+       i += threads_per_threadgroup.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  global_exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(
+      &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid);
+  const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const device T *tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  device T *out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
+       i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
+    }
+    out_ptr[i] = T(acc);
+  }
+}
+
| 1072 |
+
#define instantiate_paged_attention_inner( \
|
| 1073 |
+
type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
|
| 1074 |
+
template \
|
| 1075 |
+
[[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
|
| 1076 |
+
"_nt" #num_threads "_nsl" #num_simd_lanes \
|
| 1077 |
+
"_ps" #partition_size)]] [[kernel]] void \
|
| 1078 |
+
paged_attention<type, head_size, block_size, num_threads, \
|
| 1079 |
+
num_simd_lanes, partition_size>( \
|
| 1080 |
+
device float *exp_sums \
|
| 1081 |
+
[[buffer(0), function_constant(use_partitioning)]], \
|
| 1082 |
+
device float *max_logits \
|
| 1083 |
+
[[buffer(1), function_constant(use_partitioning)]], \
|
| 1084 |
+
device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \
|
| 1085 |
+
device const type *k_cache [[buffer(4)]], \
|
| 1086 |
+
device const type *v_cache [[buffer(5)]], \
|
| 1087 |
+
const constant int &num_kv_heads [[buffer(6)]], \
|
| 1088 |
+
const constant float &scale [[buffer(7)]], \
|
| 1089 |
+
const constant float &softcapping [[buffer(8)]], \
|
| 1090 |
+
device const uint32_t *block_tables [[buffer(9)]], \
|
| 1091 |
+
device const uint32_t *context_lens [[buffer(10)]], \
|
| 1092 |
+
const constant int &max_num_blocks_per_seq [[buffer(11)]], \
|
| 1093 |
+
device const float *alibi_slopes \
|
| 1094 |
+
[[buffer(12), function_constant(use_alibi)]], \
|
| 1095 |
+
const constant int &q_stride [[buffer(13)]], \
|
| 1096 |
+
const constant int &kv_block_stride [[buffer(14)]], \
|
| 1097 |
+
const constant int &kv_head_stride [[buffer(15)]], \
|
| 1098 |
+
threadgroup char *shared_mem [[threadgroup(0)]], \
|
| 1099 |
+
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
| 1100 |
+
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
|
| 1101 |
+
uint3 thread_position_in_threadgroup \
|
| 1102 |
+
[[thread_position_in_threadgroup]], \
|
| 1103 |
+
uint simd_tid [[simdgroup_index_in_threadgroup]], \
|
| 1104 |
+
uint simd_lid [[thread_index_in_simdgroup]]);
|
| 1105 |
+
|
| 1106 |
+
#define instantiate_paged_attention_v2_reduce_inner( \
|
| 1107 |
+
type, head_size, num_threads, num_simd_lanes, partition_size) \
|
| 1108 |
+
template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
|
| 1109 |
+
"_nt" #num_threads "_nsl" #num_simd_lanes \
|
| 1110 |
+
"_ps" #partition_size)]] [[kernel]] void \
|
| 1111 |
+
paged_attention_v2_reduce<type, head_size, num_threads, num_simd_lanes, \
|
| 1112 |
+
partition_size>( \
|
| 1113 |
+
device type * out [[buffer(0)]], \
|
| 1114 |
+
const device float *exp_sums [[buffer(1)]], \
|
| 1115 |
+
const device float *max_logits [[buffer(2)]], \
|
| 1116 |
+
const device type *tmp_out [[buffer(3)]], \
|
| 1117 |
+
device uint32_t *context_lens [[buffer(4)]], \
|
| 1118 |
+
const constant int &max_num_partitions [[buffer(5)]], \
|
| 1119 |
+
threadgroup char *shared_mem [[threadgroup(0)]], \
|
| 1120 |
+
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
|
| 1121 |
+
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
|
| 1122 |
+
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
|
| 1123 |
+
uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
|
| 1124 |
+
uint simd_tid [[simdgroup_index_in_threadgroup]], \
|
| 1125 |
+
uint simd_lid [[thread_index_in_simdgroup]]);
|
| 1126 |
+
|
| 1127 |
+
#define instantiate_paged_attention_heads(type, block_size, num_threads, \
|
| 1128 |
+
num_simd_lanes, partition_size) \
|
| 1129 |
+
instantiate_paged_attention_inner(type, 64, block_size, num_threads, \
|
| 1130 |
+
num_simd_lanes, partition_size); \
|
| 1131 |
+
instantiate_paged_attention_inner(type, 80, block_size, num_threads, \
|
| 1132 |
+
num_simd_lanes, partition_size); \
|
| 1133 |
+
instantiate_paged_attention_inner(type, 96, block_size, num_threads, \
|
| 1134 |
+
num_simd_lanes, partition_size); \
|
| 1135 |
+
instantiate_paged_attention_inner(type, 112, block_size, num_threads, \
|
| 1136 |
+
num_simd_lanes, partition_size); \
|
| 1137 |
+
instantiate_paged_attention_inner(type, 128, block_size, num_threads, \
|
| 1138 |
+
num_simd_lanes, partition_size); \
|
| 1139 |
+
instantiate_paged_attention_inner(type, 192, block_size, num_threads, \
|
| 1140 |
+
num_simd_lanes, partition_size); \
|
| 1141 |
+
instantiate_paged_attention_inner(type, 256, block_size, num_threads, \
|
| 1142 |
+
num_simd_lanes, partition_size);
|
| 1143 |
+
|
| 1144 |
+
#define instantiate_paged_attention_v2_reduce_heads( \
|
| 1145 |
+
type, num_threads, num_simd_lanes, partition_size) \
|
| 1146 |
+
instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \
|
| 1147 |
+
num_simd_lanes, partition_size); \
|
| 1148 |
+
instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \
|
| 1149 |
+
num_simd_lanes, partition_size); \
|
| 1150 |
+
instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \
|
| 1151 |
+
num_simd_lanes, partition_size); \
|
| 1152 |
+
instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \
|
| 1153 |
+
num_simd_lanes, partition_size); \
|
| 1154 |
+
instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \
|
| 1155 |
+
num_simd_lanes, partition_size); \
|
| 1156 |
+
instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \
|
| 1157 |
+
num_simd_lanes, partition_size); \
|
| 1158 |
+
instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \
|
| 1159 |
+
num_simd_lanes, partition_size);
|
| 1160 |
+
|
| 1161 |
+
#define instantiate_paged_attention_block_size(type, num_threads, \
|
| 1162 |
+
num_simd_lanes, partition_size) \
|
| 1163 |
+
instantiate_paged_attention_heads(type, 8, num_threads, num_simd_lanes, \
|
| 1164 |
+
partition_size); \
|
| 1165 |
+
instantiate_paged_attention_heads(type, 16, num_threads, num_simd_lanes, \
|
| 1166 |
+
partition_size); \
|
| 1167 |
+
instantiate_paged_attention_heads(type, 32, num_threads, num_simd_lanes, \
|
| 1168 |
+
partition_size);
|
| 1169 |
+
|
| 1170 |
+
// TODO: tune num_threads = 256
|
| 1171 |
+
// NOTE: partition_size = 0
|
| 1172 |
+
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
|
| 1173 |
+
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
|
| 1174 |
+
|
| 1175 |
+
// TODO: tune num_threads = 256
|
| 1176 |
+
// NOTE: partition_size = 512
|
| 1177 |
+
#define instantiate_paged_attention_v2(type, num_simd_lanes) \
|
| 1178 |
+
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
|
| 1179 |
+
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
|
| 1180 |
+
|
| 1181 |
+
instantiate_paged_attention_v1(float, 32);
|
| 1182 |
+
instantiate_paged_attention_v1(bfloat16_t, 32);
|
| 1183 |
+
instantiate_paged_attention_v1(half, 32);
|
| 1184 |
+
|
| 1185 |
+
instantiate_paged_attention_v2(float, 32);
|
| 1186 |
+
instantiate_paged_attention_v2(bfloat16_t, 32);
|
| 1187 |
+
instantiate_paged_attention_v2(half, 32);
|
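The `host_name` attributes in the instantiation macros above encode every template parameter into the exported kernel name, which is what the host side must look up in the compiled library. A minimal Python sketch of that naming scheme (the helper itself is hypothetical, not part of this commit; the defaults mirror the `instantiate_paged_attention_v1/v2` calls):

```python
def paged_attention_kernel_name(dtype: str, head_size: int, block_size: int,
                                partition_size: int, num_threads: int = 256,
                                num_simd_lanes: int = 32) -> str:
    # Mirrors: "paged_attention_" #type "_hs" #head_size "_bs" #block_size
    #          "_nt" #num_threads "_nsl" #num_simd_lanes "_ps" #partition_size
    return (f"paged_attention_{dtype}_hs{head_size}_bs{block_size}"
            f"_nt{num_threads}_nsl{num_simd_lanes}_ps{partition_size}")

# e.g. the v1 half-precision variant for head size 128, block size 16:
assert (paged_attention_kernel_name("half", 128, 16, 0)
        == "paged_attention_half_hs128_bs16_nt256_nsl32_ps0")
```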
paged-attention-metal/cache/copy_blocks.metal
ADDED
@@ -0,0 +1,50 @@
#include "utils.metal"
#include <metal_stdlib>

using namespace metal;

template <typename T>
[[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]],
                            device T *value_cache [[buffer(1)]],
                            const device int64_t *block_mapping [[buffer(2)]],
                            device const int &numel_per_block,
                            uint gid [[thread_position_in_grid]],
                            uint tid [[thread_position_in_threadgroup]],
                            uint threads_per_threadgroup
                            [[threads_per_threadgroup]]) {
  const int pair_idx = gid;

  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;

  // Copy key cache blocks
  for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }

  // Copy value cache blocks
  for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

#define instantiate_copy_blocks(type) \
  template [[host_name("copy_blocks_" #type)]] [[kernel]] void \
  copy_blocks<type>(device type * key_cache_ptrs [[buffer(0)]], \
                    device type * value_cache_ptrs [[buffer(1)]], \
                    const device int64_t *block_mapping [[buffer(2)]], \
                    device const int &numel_per_block, \
                    uint gid [[thread_position_in_grid]], \
                    uint tid [[thread_position_in_threadgroup]], \
                    uint threads_per_threadgroup [[threads_per_threadgroup]]);

instantiate_copy_blocks(float);
instantiate_copy_blocks(bfloat16_t);
instantiate_copy_blocks(half);
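`copy_blocks` treats each cache as a flat sequence of `numel_per_block`-element blocks and copies every `(src, dst)` pair in `block_mapping`. A hedged PyTorch sketch of the same semantics, for reference (not part of the commit):

```python
import torch

def copy_blocks_reference(key_cache: torch.Tensor, value_cache: torch.Tensor,
                          block_mapping: torch.Tensor) -> None:
    # block_mapping is a flat [2 * num_pairs] tensor of (src, dst) pairs,
    # matching block_mapping[2 * pair_idx] / [2 * pair_idx + 1] in the kernel.
    for src, dst in block_mapping.view(-1, 2).tolist():
        key_cache[dst].copy_(key_cache[src])
        value_cache[dst].copy_(value_cache[src])
```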
paged-attention-metal/cache/reshape_and_cache.metal
ADDED
@@ -0,0 +1,74 @@
#include "utils.metal"
#include <metal_stdlib>

using namespace metal;

template <typename T>
[[kernel]] void reshape_and_cache(
    const device T *__restrict__ key
    [[buffer(0)]], // [num_tokens, num_heads, head_size]
    const device T *__restrict__ value
    [[buffer(1)]], // [num_tokens, num_heads, head_size]
    device T *__restrict__ key_cache
    [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x]
    device T *__restrict__ value_cache
    [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size]
    const device int64_t *__restrict__ slot_mapping
    [[buffer(4)]], // [num_tokens]
    device const int &key_stride, device const int &value_stride,
    device const int &num_heads, device const int &head_size,
    device const int &block_size, device const int &x,
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threads_per_threadgroup [[threads_per_threadgroup]]) {
  const int64_t token_idx = gid;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = tid; i < n; i += threads_per_threadgroup) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;
    key_cache[tgt_key_idx] = key[src_key_idx];
    value_cache[tgt_value_idx] = value[src_value_idx];
  }
}

#define instantiate_reshape_and_cache(type) \
  template [[host_name("reshape_and_cache_" #type)]] [[kernel]] void \
  reshape_and_cache<type>( \
      const device type *__restrict__ key [[buffer(0)]], \
      const device type *__restrict__ value [[buffer(1)]], \
      device type *__restrict__ key_cache [[buffer(2)]], \
      device type *__restrict__ value_cache [[buffer(3)]], \
      const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
      device const int &key_stride, device const int &value_stride, \
      device const int &num_heads, device const int &head_size, \
      device const int &block_size, device const int &x, \
      uint gid [[threadgroup_position_in_grid]], \
      uint tid [[thread_position_in_threadgroup]], \
      uint threads_per_threadgroup [[threads_per_threadgroup]]);

instantiate_reshape_and_cache(float);
instantiate_reshape_and_cache(bfloat16_t);
instantiate_reshape_and_cache(half);
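The target-index arithmetic above linearizes the `[num_blocks, num_heads, head_size/x, block_size, x]` key-cache layout. A small Python mirror of `tgt_key_idx`, handy for checking offsets by hand (illustrative only, not part of the commit):

```python
def key_cache_flat_index(block_idx: int, head_idx: int, head_offset: int,
                         block_offset: int, num_heads: int, head_size: int,
                         block_size: int, x: int) -> int:
    # Row-major index into [num_blocks, num_heads, head_size // x, block_size, x].
    x_idx, x_offset = divmod(head_offset, x)
    return (block_idx * num_heads * (head_size // x) * block_size * x
            + head_idx * (head_size // x) * block_size * x
            + x_idx * block_size * x
            + block_offset * x
            + x_offset)
```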
paged-attention-metal/paged_attention.mm
ADDED
@@ -0,0 +1,117 @@
#include <torch/torch.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#include <string>

char const *CUSTOM_KERNEL = R"(
#include <metal_stdlib>
using namespace metal;
kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
                                      device float *outC [[buffer(1)]],
                                      uint index [[thread_position_in_grid]]) {
  // Explicitly write to output
  outC[index] = max(0.0f, inA[index]);
}
kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
                                     device half *outC [[buffer(1)]],
                                     uint index [[thread_position_in_grid]]) {
  // Explicitly write to output
  outC[index] = max(static_cast<half>(0.0), inA[index]);
}
)";

static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
                                  torch::Tensor &output) {
  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    NSError *error = nil;

    int numThreads = input.numel();

    id<MTLLibrary> customKernelLibrary = [device
        newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL]
                     options:nil
                       error:&error];
    TORCH_CHECK(customKernelLibrary,
                "Failed to create custom kernel library, error: ",
                error.localizedDescription.UTF8String);

    std::string kernel_name =
        std::string("relu_forward_kernel_") +
        (input.scalar_type() == torch::kFloat ? "float" : "half");
    id<MTLFunction> customReluFunction = [customKernelLibrary
        newFunctionWithName:[NSString
                                stringWithUTF8String:kernel_name.c_str()]];
    TORCH_CHECK(customReluFunction,
                "Failed to create function state object for ",
                kernel_name.c_str());

    id<MTLComputePipelineState> reluPSO =
        [device newComputePipelineStateWithFunction:customReluFunction
                                              error:&error];
    TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);

    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
    TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

    dispatch_sync(serialQueue, ^() {
      id<MTLComputeCommandEncoder> computeEncoder =
          [commandBuffer computeCommandEncoder];
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

      [computeEncoder setComputePipelineState:reluPSO];
      [computeEncoder setBuffer:getMTLBufferStorage(input)
                         offset:input.storage_offset() * input.element_size()
                        atIndex:0];
      [computeEncoder setBuffer:getMTLBufferStorage(output)
                         offset:output.storage_offset() * output.element_size()
                        atIndex:1];

      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

      NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup;
      if (threadGroupSize > numThreads) {
        threadGroupSize = numThreads;
      }
      MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

      [computeEncoder dispatchThreads:gridSize
                threadsPerThreadgroup:threadgroupSize];

      [computeEncoder endEncoding];

      torch::mps::commit();
    });
  }

  return output;
}

void relu(torch::Tensor &out, const torch::Tensor &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == torch::kFloat ||
                  input.scalar_type() == torch::kHalf,
              "Unsupported data type: ", input.scalar_type());

  TORCH_CHECK(input.sizes() == out.sizes(),
              "Tensors must have the same shape. Got input shape: ",
              input.sizes(), " and output shape: ", out.sizes());

  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have the same data type. Got input dtype: ",
              input.scalar_type(), " and output dtype: ", out.scalar_type());

  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on the same device. Got input device: ",
              input.device(), " and output device: ", out.device());

  dispatchReluKernel(input, out);
}
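Note that this translation unit currently compiles and dispatches only the sample ReLU kernel embedded in `CUSTOM_KERNEL`, not the paged-attention pipelines above. Assuming the torch extension exposes `relu` with the `relu(out, input)` signature shown (the module name below is an assumption based on the tests), host-side usage would look roughly like:

```python
import torch
import paged_attention as ops  # module name assumed from the tests below

x = torch.randn(1024, dtype=torch.float32, device="mps")
out = torch.empty_like(x)
ops.relu(out, x)  # out = max(x, 0), computed by the Metal kernel
assert torch.equal(out, x.clamp_min(0))
```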
paged-attention-metal/utils.metal
ADDED
File without changes
tests/kernels/__init__.py
ADDED
File without changes
tests/kernels/allclose_default.py
ADDED
@@ -0,0 +1,14 @@
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}


def get_default_atol(output) -> float:
    return default_atol[output.dtype]


def get_default_rtol(output) -> float:
    return default_rtol[output.dtype]
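These helpers pick dtype-appropriate tolerances for `torch.testing.assert_close`; a typical call site looks like this (sketch; the import path is illustrative):

```python
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol

out = torch.randn(4, dtype=torch.float16)
ref = out.clone()
torch.testing.assert_close(out, ref,
                           atol=get_default_atol(out),  # 1e-3 for float16
                           rtol=get_default_rtol(out))  # 1e-3 for float16
```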
tests/kernels/conftest.py
ADDED
@@ -0,0 +1,158 @@
from typing import List, Optional, Tuple, Union

import paged_attention as ops
import pytest
import torch


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches_with_random


@pytest.fixture()
def kv_cache_factory_flashinfer():
    return create_kv_caches_with_random_flash


STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:

    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(
            size=value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=key_value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches


def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype


def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE(zhaoyang): Due to the NaN and Inf representations of the fp8 data
    # types, Inf or NaN may occur if we directly use torch.randint to
    # generate random fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #      | E4M3       | E5M2
    # -----|------------|-------------------
    # Inf  | N/A        | s.11111.00
    # NaN  | s.1111.111 | s.11111.{01,10,11}
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp
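For orientation: with fp16 caches, `x = 16 // 2 = 8`, so the factory returns key caches of shape `[num_blocks, num_heads, head_size // 8, block_size, 8]` and value caches of `[num_blocks, num_heads, head_size, block_size]`, matching the layouts the Metal kernels above expect. A quick sketch (import path illustrative):

```python
from tests.kernels.conftest import create_kv_caches_with_random

key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=8, block_size=16, num_layers=1, num_heads=4,
    head_size=128, cache_dtype="half", device="cuda",
)
assert key_caches[0].shape == (8, 4, 128 // 8, 16, 8)
assert value_caches[0].shape == (8, 4, 128, 16)
```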
tests/kernels/test_attention.py
ADDED
@@ -0,0 +1,418 @@
import random
from typing import List, Optional, Tuple

import paged_attention as ops
import pytest
import torch
from paged_attention.platforms import current_platform

from .allclose_default import get_default_atol, get_default_rtol
from .utils import get_max_shared_memory_bytes, opcheck

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = (
    [torch.half, torch.bfloat16, torch.float]
    if not current_platform.is_rocm()
    else [torch.half, torch.bfloat16]
)
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# This should be kept in sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize(
    "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    if (kv_cache_dtype == "fp8" and head_size % 16) or (
        version == "rocm" and head_size not in (64, 128)
    ):
        pytest.skip()

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(
            ops.ops.paged_attention_v1,
            (
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                0,
                0,
                0,
                64,
                0,
            ),
            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
        )

    elif version in ("v2", "rocm"):
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                ops.ops.paged_attention_v2,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    0,
                    0,
                    0,
                    64,
                    0,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)
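Once the extension is built, a slice of this suite can be run with the usual pytest selectors, e.g. `pytest tests/kernels/test_attention.py -k v1` to exercise only the v1 kernel paths.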
tests/kernels/test_cache.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
import paged_attention as ops
|
| 5 |
+
import pytest
|
| 6 |
+
import torch
|
| 7 |
+
from paged_attention.platforms import current_platform
|
| 8 |
+
|
| 9 |
+
from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
| 10 |
+
|
| 11 |
+
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
|
| 12 |
+
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
| 13 |
+
NUM_TOKENS = [42] # Arbitrary values for testing
|
| 14 |
+
NUM_LAYERS = [1] # Arbitrary values for testing
|
| 15 |
+
NUM_HEADS = [8] # Arbitrary values for testing
|
| 16 |
+
HEAD_SIZES = [64, 80, 120, 256]
|
| 17 |
+
BLOCK_SIZES = [8, 16, 32]
|
| 18 |
+
|
| 19 |
+
# Arbitrary values for testing
|
| 20 |
+
# don't make it too large. e.g. [1024, 36000] will OOM
|
| 21 |
+
NUM_BLOCKS = [1024, 10000]
|
| 22 |
+
|
| 23 |
+
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
| 24 |
+
SEEDS = [0]
|
| 25 |
+
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
| 26 |
+
|
| 27 |
+
# We assume fp8 is always enabled for testing.
|
| 28 |
+
KV_CACHE_DTYPE = ["auto", "fp8"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
| 32 |
+
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
| 33 |
+
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
| 34 |
+
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
| 35 |
+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
| 36 |
+
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
| 37 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
| 38 |
+
@pytest.mark.parametrize("seed", SEEDS)
|
| 39 |
+
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
| 40 |
+
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
| 41 |
+
@torch.inference_mode()
|
| 42 |
+
def test_copy_blocks(
|
| 43 |
+
kv_cache_factory,
|
| 44 |
+
num_mappings: int,
|
| 45 |
+
num_layers: int,
|
| 46 |
+
num_heads: int,
|
| 47 |
+
head_size: int,
|
| 48 |
+
block_size: int,
|
| 49 |
+
num_blocks: int,
|
| 50 |
+
dtype: torch.dtype,
|
| 51 |
+
seed: int,
|
| 52 |
+
kv_cache_dtype: str,
|
| 53 |
+
device: str,
|
| 54 |
+
) -> None:
|
| 55 |
+
if kv_cache_dtype == "fp8" and head_size % 16:
|
| 56 |
+
pytest.skip()
|
| 57 |
+
current_platform.seed_everything(seed)
|
| 58 |
+
torch.set_default_device(device)
|
| 59 |
+
# Generate random block mappings where each source block is mapped to two
|
| 60 |
+
# destination blocks.
|
| 61 |
+
assert 2 * num_mappings <= num_blocks
|
| 62 |
+
src_blocks = random.sample(range(num_blocks), num_mappings)
|
| 63 |
+
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
| 64 |
+
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
| 65 |
+
block_mapping: List[Tuple[int, int]] = []
|
| 66 |
+
for i in range(num_mappings):
|
| 67 |
+
src = src_blocks[i]
|
| 68 |
+
dst1 = dst_blocks[2 * i]
|
| 69 |
+
dst2 = dst_blocks[2 * i + 1]
|
| 70 |
+
block_mapping.append((src, dst1))
|
| 71 |
+
block_mapping.append((src, dst2))
|
| 72 |
+
|
| 73 |
+
# Create the KV caches.
|
| 74 |
+
key_caches, value_caches = kv_cache_factory(
|
| 75 |
+
num_blocks,
|
| 76 |
+
block_size,
|
| 77 |
+
num_layers,
|
| 78 |
+
num_heads,
|
| 79 |
+
head_size,
|
| 80 |
+
kv_cache_dtype,
|
| 81 |
+
dtype,
|
| 82 |
+
seed,
|
| 83 |
+
device,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Clone the KV caches.
|
| 87 |
+
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
| 88 |
+
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
| 89 |
+
|
| 90 |
+
# Call the copy blocks kernel.
|
| 91 |
+
block_mapping_tensor = torch.tensor(
|
| 92 |
+
block_mapping, dtype=torch.int64, device=device
|
| 93 |
+
).view(-1, 2)
|
| 94 |
+
|
| 95 |
+
opcheck(
|
| 96 |
+
ops.ops.copy_blocks,
|
| 97 |
+
(key_caches, value_caches, block_mapping_tensor),
|
| 98 |
+
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
| 99 |
+
cond=(head_size == HEAD_SIZES[0]),
|
| 100 |
+
)
|
| 101 |
+
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
|
| 102 |
+
|
| 103 |
+
# Run the reference implementation.
|
| 104 |
+
for src, dst in block_mapping:
|
| 105 |
+
for cloned_key_cache in cloned_key_caches:
|
| 106 |
+
cloned_key_cache[dst].copy_(cloned_key_cache[src])
|
| 107 |
+
for cloned_value_cache in cloned_value_caches:
|
| 108 |
+
cloned_value_cache[dst].copy_(cloned_value_cache[src])
|
| 109 |
+
|
| 110 |
+
# Compare the results.
|
| 111 |
+
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
| 112 |
+
torch.testing.assert_close(key_cache, cloned_key_cache)
|
| 113 |
+
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
| 114 |
+
torch.testing.assert_close(value_cache, cloned_value_cache)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@torch.inference_mode()
+def test_reshape_and_cache(
+    kv_cache_factory,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+    # Create a random slot mapping.
+    num_slots = block_size * num_blocks
+    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
+
+    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
+    _, key, value = qkv.unbind(dim=1)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+        seed,
+        device,
+    )
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
+
+    # Using default kv_scale
+    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
+
+    # Call the reshape_and_cache kernel.
+    opcheck(
+        ops.ops.reshape_and_cache,
+        (
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        ),
+        cond=(head_size == HEAD_SIZES[0]),
+    )
+    ops.reshape_and_cache(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(result_key_cache, key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(result_value_cache, value_cache)
+
+    # Run the reference implementation.
+    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
+    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    block_indices_lst = block_indices.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets_lst = block_offsets.cpu().tolist()
+    for i in range(num_tokens):
+        block_idx = block_indices_lst[i]
+        block_offset = block_offsets_lst[i]
+        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
+        cloned_value_cache[block_idx, :, :, block_offset] = value[i]
+
+    if kv_cache_dtype == "fp8":
+        torch.testing.assert_close(
+            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
+        )
+        torch.testing.assert_close(
+            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
+        )
+    else:
+        torch.testing.assert_close(key_cache, cloned_key_cache)
+        torch.testing.assert_close(value_cache, cloned_value_cache)
+
+
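The reference implementation above hinges on decomposing each flat slot index into a (block index, in-block offset) pair. A minimal sketch of that arithmetic, with a hypothetical block size:

import torch

block_size = 16  # hypothetical
slot_mapping = torch.tensor([0, 5, 16, 37], dtype=torch.long)

# A slot s maps to block s // block_size at offset s % block_size.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

# Slot 37 lives in block 2 at offset 5, since 37 == 2 * 16 + 5.
print(block_indices.tolist())  # [0, 0, 1, 2]
print(block_offsets.tolist())  # [0, 5, 0, 5]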
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@torch.inference_mode()
+def test_reshape_and_cache_flash(
+    kv_cache_factory_flashinfer,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    # Create a random slot mapping.
+    num_slots = block_size * num_blocks
+    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
+
+    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
+    _, key, value = qkv.unbind(dim=1)
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory_flashinfer(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+        device=device,
+    )
+    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
+    del key_caches
+    del value_caches
+
+    k_scale = (key.amax() / 256.0).to(torch.float32)
+    v_scale = (value.amax() / 256.0).to(torch.float32)
+
+    # Clone the KV caches.
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
+
+    # Call the reshape_and_cache kernel.
+    opcheck(
+        ops.ops.reshape_and_cache_flash,
+        (
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        ),
+        cond=(head_size == HEAD_SIZES[0]),
+    )
+    ops.reshape_and_cache_flash(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(
+            result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
+        )
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(
+            result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
+        )
+
+    # Run the reference implementation.
+    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+    block_indices_lst = block_indices.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets_lst = block_offsets.cpu().tolist()
+    for i in range(num_tokens):
+        block_idx = block_indices_lst[i]
+        block_offset = block_offsets_lst[i]
+        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+
+    if kv_cache_dtype == "fp8":
+        torch.testing.assert_close(
+            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
+        )
+        torch.testing.assert_close(
+            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
+        )
+    else:
+        torch.testing.assert_close(key_cache, cloned_key_cache)
+        torch.testing.assert_close(value_cache, cloned_value_cache)
+
+
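Note how the reference writes differ from `test_reshape_and_cache`: the flash-style cache stores one contiguous [num_heads, head_size] row per slot, i.e. a [num_blocks, block_size, num_heads, head_size] layout, while the classic layout splits the key's head dimension and keeps block_size innermost. A shape-only sketch of the two layouts, with hypothetical dimensions:

import torch

# Hypothetical dimensions; x is the vector width the classic key layout packs on.
num_blocks, block_size, num_heads, head_size, x = 2, 16, 4, 64, 8

# Classic paged layout (what test_reshape_and_cache indexes).
key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.empty(num_blocks, num_heads, head_size, block_size)

# Flash-style layout (what this test indexes).
key_cache_flash = torch.empty(num_blocks, block_size, num_heads, head_size)

# Writing one token's key into slot (block_idx=0, block_offset=3):
token_key = torch.randn(num_heads, head_size)
key_cache_flash[0, 3] = token_key  # one contiguous row
key_cache[0, :, :, 3, :] = token_key.view(num_heads, head_size // x, x)  # scattered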
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+) -> None:
+    if kv_cache_dtype == "fp8" and "cpu" in direction:
+        pytest.skip()
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
+
+    current_platform.seed_everything(seed)
+
+    src_device = device if direction[0] == "cuda" else "cpu"
+    dst_device = device if direction[1] == "cuda" else "cpu"
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # For the same device, mapping must not overlap
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(
+        block_mapping, dtype=torch.int64, device="cpu"
+    ).view(-1, 2)
+
+    # Create the KV caches on the first device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+        seed,
+        src_device,
+    )
+
+    # Create the KV caches on the second device.
+    dist_key_caches, dist_value_caches = kv_cache_factory(
+        num_blocks,
+        block_size,
+        1,
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+        seed,
+        dst_device,
+    )
+
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    do_opcheck = head_size == HEAD_SIZES[0]
+    opcheck(
+        ops.ops.swap_blocks,
+        (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
+        cond=do_opcheck,
+    )
+    opcheck(
+        ops.ops.swap_blocks,
+        (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
+        cond=do_opcheck,
+    )
+
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
+
+    for src, dst in block_mapping:
+        torch.testing.assert_close(
+            src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
+        )
+        torch.testing.assert_close(
+            src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
+        )
+
+
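One detail worth noting above: when source and destination live on the same device, destination blocks are drawn from the complement of the source set, so no block is both read and written and the expected result stays order-independent. A minimal sketch of that sampling, with hypothetical sizes:

import random

num_blocks, num_mappings = 8, 2  # hypothetical
src_blocks = random.sample(range(num_blocks), num_mappings)

# Sampling destinations from the complement keeps src and dst disjoint.
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining_blocks, num_mappings)

assert not set(src_blocks) & set(dst_blocks)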
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_fp8_e4m3_conversion(
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    current_platform.seed_everything(seed)
+
+    low = -224.0
+    high = 224.0
+    shape = (num_blocks, num_heads, head_size, block_size)
+    cache = torch.empty(shape, dtype=dtype, device=device)
+    cache.uniform_(low, high)
+
+    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
+    ops.convert_fp8(cache_fp8, cache)
+
+    converted_cache = torch.empty_like(cache)
+    ops.convert_fp8(converted_cache, cache_fp8)
+
+    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
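The ±224 bounds keep the sampled values well inside the finite range of the FP8 E4M3 formats (E4M3FN tops out at ±448, the FNUZ variant at ±240). As a rough cross-check of the expected round-trip error, independent of the kernel under test, PyTorch's native float8 dtype behaves similarly; a sketch assuming PyTorch >= 2.1:

import torch

x = torch.empty(1024, dtype=torch.float32).uniform_(-224.0, 224.0)

# Quantize to FP8 E4M3 and back. With 3 mantissa bits the relative
# round-trip error stays below 2**-4, well inside rtol=0.1.
x_roundtrip = x.to(torch.float8_e4m3fn).to(torch.float32)

torch.testing.assert_close(x, x_roundtrip, atol=0.001, rtol=0.1)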
tests/kernels/utils.py
ADDED
@@ -0,0 +1,92 @@
+"""Kernel test utils"""
+
+import itertools
+import random
+import unittest.mock
+from functools import lru_cache
+from numbers import Number
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+import pytest
+import torch
+from torch._prims_common import TensorLikeType
+
+# For now, disable "test_aot_dispatch_dynamic" since there are some
+# bugs related to this test in PyTorch 2.4.
+DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+)
+
+ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+    "test_aot_dispatch_dynamic",
+)
+
+
+# Copied/modified from torch._refs.__init__.py
+def fp8_allclose(
+    a: TensorLikeType,
+    b: TensorLikeType,
+    rtol: float = 1e-05,
+    atol: float = 1e-08,
+    equal_nan: bool = False,
+) -> bool:
+    """
+    Reference implementation of torch.allclose
+    """
+    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
+
+    return bool(
+        torch.all(
+            torch.isclose(
+                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
+            )
+        ).item()
+    )
+
+
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref)
+    )
+
+
+# A special version of opcheck that has a restricted default set of test_utils
+# and a patched version of allclose that supports fp8 types.
+def opcheck(
+    op: Union[
+        torch._ops.OpOverload,
+        torch._ops.OpOverloadPacket,
+        torch._library.custom_ops.CustomOpDef,
+    ],
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+    raise_exception: bool = True,
+    cond: bool = True,
+) -> Dict[str, str]:
+    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
+        return (
+            torch.library.opcheck(
+                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+            )
+            if cond
+            else {}
+        )
+
+
+@lru_cache(maxsize=None)
+def get_max_shared_memory_bytes(gpu: int = 0) -> int:
+    """Returns the maximum shared memory per thread block in bytes."""
+    from paged_attention import ops
+
+    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
+    # A value of 0 would make MAX_SEQ_LEN negative and test_attention.py
+    # would fail.
+    assert max_shared_mem > 0, "max_shared_mem cannot be zero"
+    return int(max_shared_mem)
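For illustration, a minimal sketch of how these helpers compose; the `mylib::scale` op is hypothetical and assumes the PyTorch >= 2.4 custom-op API, whereas the real tests pass `ops.ops.*` from the extension:

import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, fp8_allclose, opcheck

# Hypothetical custom op registered for the sake of the example.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

x = torch.randn(8)
# cond=False skips the whole check, which is how the tests above limit
# opcheck to a single head size.
opcheck(scale, (x, 2.0), test_utils=DEFAULT_OPCHECK_TEST_UTILS, cond=True)

# fp8_allclose upcasts to float64 before comparing, so it also accepts
# float8 inputs that plain torch.allclose would reject.
assert fp8_allclose(scale(x, 2.0), x * 2.0)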