danieldk's picture
danieldk HF staff
Rename to paged-attention
3dcba92
raw
history blame
5.48 kB
from typing import List, Optional, Tuple, Union
import paged_attention as ops
import pytest
import torch
@pytest.fixture()
def kv_cache_factory():
    """Fixture that hands tests the ``create_kv_caches_with_random`` builder.

    The function object itself is returned (it is not called here), so each
    test invokes it with its own cache dimensions and dtype.
    """
    factory = create_kv_caches_with_random
    return factory
@pytest.fixture()
def kv_cache_factory_flashinfer():
    """Fixture that hands tests the ``create_kv_caches_with_random_flash`` builder.

    The function object itself is returned (it is not called here), so each
    test invokes it with its own cache dimensions and dtype.
    """
    factory = create_kv_caches_with_random_flash
    return factory
# Maps dtype names used by the cache helpers to torch dtypes.
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    # Every fp8 spelling maps to single-byte ``torch.uint8`` storage.
    **dict.fromkeys(("fp8", "fp8_e4m3", "fp8_e5m2"), torch.uint8),
}
def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Build per-layer key and value cache tensors filled with random data.

    Args:
        num_blocks: Size of the leading dimension of every cache tensor.
        block_size: Size of the block dimension of every cache tensor.
        num_layers: Number of key tensors and of value tensors to create.
        num_heads: Size of the head dimension of every cache tensor.
        head_size: Per-head size; also sets the random range via
            ``scale = head_size ** -0.5``.
        cache_dtype: Cache dtype specifier, resolved to a torch dtype by
            ``get_kv_cache_torch_dtype``. Only the strings ``"auto"``,
            ``"half"``, ``"bfloat16"``, ``"float"`` and ``"fp8"`` can be
            filled; any other value is rejected when a layer is created.
        model_dtype: Forwarded to ``get_kv_cache_torch_dtype``.
        seed: Passed to ``current_platform.seed_everything`` before any
            tensor is created.
        device: Device on which the tensors are allocated.

    Returns:
        ``(key_caches, value_caches)``, each a list of ``num_layers`` tensors.
        Key tensors have shape
        ``(num_blocks, num_heads, head_size // x, block_size, x)`` where ``x``
        is ``16 // element_size`` of the resolved dtype; value tensors have
        shape ``(num_blocks, num_heads, head_size, block_size)``.

    Raises:
        ValueError: If ``cache_dtype`` is ``"fp8"`` and ``head_size`` is not a
            multiple of 16, or if a layer is created for a ``cache_dtype``
            that has no fill strategy.
    """
    if cache_dtype == "fp8" and head_size % 16 != 0:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )

    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    bound = head_size**-0.5

    def _randomize(cache: torch.Tensor) -> bool:
        """Fill ``cache`` in place; return False if ``cache_dtype`` is unsupported."""
        if cache_dtype in ("auto", "half", "bfloat16", "float"):
            cache.uniform_(-bound, bound)
        elif cache_dtype == "fp8":
            _generate_random_fp8(cache, -bound, bound)
        else:
            return False
        return True

    # Number of elements of the resolved dtype that fit in 16 bytes; the key
    # cache splits head_size into groups of this many elements.
    elems_per_16_bytes = 16 // torch.tensor([], dtype=dtype).element_size()
    key_shape = (
        num_blocks,
        num_heads,
        head_size // elems_per_16_bytes,
        block_size,
        elems_per_16_bytes,
    )
    value_shape = (num_blocks, num_heads, head_size, block_size)

    # All key tensors are generated before any value tensor so the sequence of
    # random draws stays the same for a given seed.
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        layer_keys = torch.empty(size=key_shape, dtype=dtype, device=device)
        if not _randomize(layer_keys):
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(layer_keys)

    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        layer_values = torch.empty(size=value_shape, dtype=dtype, device=device)
        if not _randomize(layer_values):
            raise ValueError(f"Does not support value cache of type {cache_dtype}")
        value_caches.append(layer_values)

    return key_caches, value_caches
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Build per-layer key/value caches that share one combined random tensor.

    For each layer a single tensor of shape
    ``(num_blocks, 2, block_size, num_heads, head_size)`` is allocated and
    filled; index 0 of the second dimension is returned as the key cache and
    index 1 as the value cache.

    Args:
        num_blocks: Size of the leading dimension of the combined tensor.
        block_size: Size of the block dimension.
        num_layers: Number of combined tensors (and so key/value pairs).
        num_heads: Size of the head dimension.
        head_size: Per-head size; also sets the random range via
            ``scale = head_size ** -0.5``.
        cache_dtype: Cache dtype specifier, resolved to a torch dtype by
            ``get_kv_cache_torch_dtype``. Only the strings ``"auto"``,
            ``"half"``, ``"bfloat16"``, ``"float"`` and ``"fp8"`` can be
            filled; any other value is rejected when a layer is created.
        model_dtype: Forwarded to ``get_kv_cache_torch_dtype``.
        seed: Passed to ``current_platform.seed_everything`` before any
            tensor is created.
        device: Device on which the tensors are allocated.

    Returns:
        ``(key_caches, value_caches)``, each a list of ``num_layers`` tensors
        of shape ``(num_blocks, block_size, num_heads, head_size)`` sliced
        from the same per-layer combined tensor.

    Raises:
        ValueError: If a layer is created for a ``cache_dtype`` that has no
            fill strategy.
    """
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    bound = head_size**-0.5
    combined_shape = (num_blocks, 2, block_size, num_heads, head_size)

    combined_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        combined = torch.empty(size=combined_shape, dtype=dtype, device=device)
        if cache_dtype in ("auto", "half", "bfloat16", "float"):
            combined.uniform_(-bound, bound)
        elif cache_dtype == "fp8":
            _generate_random_fp8(combined, -bound, bound)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        combined_caches.append(combined)

    # Split each combined tensor along its second dimension: 0 = keys, 1 = values.
    key_caches: List[torch.Tensor] = [kv[:, 0] for kv in combined_caches]
    value_caches: List[torch.Tensor] = [kv[:, 1] for kv in combined_caches]
    return key_caches, value_caches
def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    """Resolve a KV-cache dtype specifier to a concrete ``torch.dtype``.

    Args:
        cache_dtype: Either a ``torch.dtype`` (returned unchanged) or one of
            the strings ``"auto"``, ``"half"``, ``"bfloat16"``, ``"float"``
            or ``"fp8"``. ``"auto"`` defers to ``model_dtype``; ``"fp8"``
            resolves to ``torch.uint8``.
        model_dtype: Only consulted when ``cache_dtype == "auto"``. Either a
            ``torch.dtype`` (returned unchanged) or a key of
            ``STR_DTYPE_TO_TORCH_DTYPE``.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: If ``cache_dtype`` is neither a supported string nor a
            ``torch.dtype``, or if ``cache_dtype`` is ``"auto"`` and
            ``model_dtype`` is neither a ``torch.dtype`` nor a known dtype
            name.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        if isinstance(model_dtype, str):
            try:
                return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            except KeyError:
                # An unknown name is reported the same way as any other
                # invalid model dtype instead of leaking a bare KeyError.
                raise ValueError(f"Invalid model dtype: {model_dtype}") from None
        raise ValueError(f"Invalid model dtype: {model_dtype}")

    if cache_dtype in ("half", "bfloat16", "float"):
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if cache_dtype == "fp8":
        return torch.uint8
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Populate ``tensor`` with fp8 data derived from uniform float16 samples.

    Args:
        tensor: Destination tensor, passed as the first argument to
            ``ops.convert_fp8`` (presumably written in place; confirm against
            the kernel's signature).
        low: Lower bound handed to ``uniform_`` for the float16 samples.
        high: Upper bound handed to ``uniform_`` for the float16 samples.
    """
    # NOTE(zhaoyang): fp8 formats reserve some bit patterns for Inf and NaN,
    # so filling the raw bytes with torch.randint could produce Inf or NaN
    # values. For example, s.11111.00 in fp8e5m2 format represents Inf.
    #      | E4M3        | E5M2
    # -----|-------------|-------------------
    # Inf  | N/A         | s.11111.00
    # NaN  | s.1111.111  | s.11111.{01,10,11}
    # Sampling in float16 first and converting avoids those patterns.
    float16_samples = torch.empty_like(tensor, dtype=torch.float16).uniform_(
        low, high
    )
    ops.convert_fp8(tensor, float16_samples)