import random
from typing import List, Optional, Tuple

import paged_attention as ops
import pytest
import torch
from paged_attention.platforms import current_platform

from .allclose_default import get_default_atol, get_default_rtol
from .utils import get_max_shared_memory_bytes, opcheck

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
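# Derived from the device's shared-memory size (one float32 per position),
# with 512 subtracted as a safety margin.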
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
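
# Arbitrary number of cache blocks for testing; reduce it if the KV cache does
# not fit in GPU memory.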
NUM_BLOCKS = 4321
PARTITION_SIZE = 512
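
# float32 is covered everywhere except ROCm, where only half and bfloat16 are tested.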
DTYPES = (
    [torch.half, torch.bfloat16, torch.float]
    if not current_platform.is_rocm()
    else [torch.half, torch.bfloat16]
)
NUM_GEN_SEQS = [7]
NUM_PREFILL_SEQS = [3]
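# (num_query_heads, num_kv_heads): one MHA and one GQA configuration.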
NUM_HEADS = [(40, 40), (64, 8)]

HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
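    # query: [num_q_tokens, num_heads, head_size]; key/value: [num_kv_tokens,
    # num_heads, head_size]. Returns an output with the same layout as query.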
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])
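
        # Reconstruct this sequence's keys and values token by token by walking
        # its block table.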
        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size
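
            # key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x];
            # value_cache layout: [num_blocks, num_kv_heads, head_size, block_size].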
            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
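            # MQA/GQA: replicate each KV head across its group of query heads.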
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
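            # ALiBi bias for a single query at the last position: per-head slope
            # times each key's (non-positive) offset from the query position.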
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize(
    "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    if (kv_cache_dtype == "fp8" and head_size % 16) or (
        version == "rocm" and head_size not in (64, 128)
    ):
        pytest.skip(
            "unsupported combination: fp8 KV cache requires head_size divisible "
            "by 16, and the ROCm kernel only supports head sizes 64 and 128"
        )

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)
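
    # Build a random block table for each sequence; block indices may repeat.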
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
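
    # Create the KV caches.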
    key_caches, value_caches = kv_cache_factory(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
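
    # Use unit k/v scales, i.e. no extra quantization scaling is applied.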
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
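
    # Call the paged attention kernel under test.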
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
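
        # Additionally exercise the underlying custom op via opcheck, restricted
        # by `cond` to a single representative head size and block size.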
        opcheck(
            ops.ops.paged_attention_v1,
            (
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                0,
                0,
                0,
                64,
                0,
            ),
            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
        )

    elif version in ("v2", "rocm"):
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
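        # The v2/rocm kernels split each sequence into PARTITION_SIZE-sized
        # partitions and then reduce the partial results, so per-partition
        # scratch buffers (tmp_output, exp_sums, max_logits) are needed.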
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        if version == "v2":
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                ops.ops.paged_attention_v2,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                    0,
                    0,
                    0,
                    64,
                    0,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

        else:
            ops.paged_attention_rocm(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

            opcheck(
                torch.ops._rocm_C.paged_attention,
                (
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    k_scale,
                    v_scale,
                ),
                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
            )

    else:
        raise AssertionError(f"Unknown version: {version}")
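
    # Run the reference implementation.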
    if kv_cache_dtype == "fp8":
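        # Dequantize the fp8 KV cache back to `dtype` so the reference
        # implementation can read it; `x` is how many elements of `dtype`
        # pack into 16 bytes along the key cache's last dimension.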
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
        dequantized_key_cache = torch.empty(
            size=key_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(
            size=value_cache_shape, dtype=dtype, device=device
        )
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )
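
    # The kernel and the reference implementation differ slightly in their
    # floating-point behavior, so compare with relaxed tolerances.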
    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
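    # An fp8 KV cache adds quantization error, so relax the tolerance further.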
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
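    # cu_seq_lens holds cumulative sequence-start offsets: sequence i occupies
    # rows [cu_seq_lens[i], cu_seq_lens[i + 1]) of query, key, and value.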
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx
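
        # Causal mask: each position may attend only to itself and earlier positions.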
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)