|
"""Test utilities for building randomly initialized KV caches for paged attention."""

from typing import List, Optional, Tuple, Union

import paged_attention as ops
import pytest
import torch
|
|
@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches_with_random


@pytest.fixture()
def kv_cache_factory_flashinfer():
    return create_kv_caches_with_random_flash
|
|
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    # The fp8 variants are stored as raw uint8 in this mapping.
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}
|
|
def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer key and value caches in the paged-attention layout,
    filled with random values."""
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    # The key cache packs the head dimension into chunks of x elements
    # (16 bytes per chunk), matching the access pattern of the attention kernel.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(
            size=value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches
|
|
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer key and value caches in the flash-attention layout,
    filled with random values."""
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    # Keys and values of a layer live in one combined tensor;
    # dimension 1 selects key (0) or value (1).
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=key_value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches
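
# Note: unlike create_kv_caches_with_random, which uses the x-packed paged layout,
# create_kv_caches_with_random_flash allocates the keys and values of each layer in
# a single (num_blocks, 2, block_size, num_heads, head_size) tensor and returns
# views into it, so key_caches[i] and value_caches[i] share storage.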
|
|
|
|
|
def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    """Resolve a cache dtype given as a string or torch.dtype into a torch.dtype."""
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype
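
# Illustrative resolutions (these follow directly from the branches above):
#   get_kv_cache_torch_dtype("auto", torch.bfloat16) -> torch.bfloat16
#   get_kv_cache_torch_dtype("half", None)           -> torch.float16
#   get_kv_cache_torch_dtype("fp8", None)            -> torch.uint8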
|
|
|
|
|
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # FP8 bit patterns include NaN/Inf encodings, so instead of writing random
    # bytes directly we draw uniform values in float16 and convert them to fp8
    # with the conversion kernel, writing the result into `tensor` in place.
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp
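

# Illustrative usage sketch (an assumption, not part of the original test suite; the
# leading underscore keeps pytest from collecting it). The parameter values are
# arbitrary, and a CUDA device is required since the caches default to device="cuda".
def _example_create_kv_caches_usage():
    # Equivalent to requesting the kv_cache_factory fixture in a test and calling it.
    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks=128,
        block_size=16,
        num_layers=2,
        num_heads=8,
        head_size=64,
        cache_dtype="auto",
        model_dtype="half",
        seed=0,
        device="cuda",
    )
    assert len(key_caches) == len(value_caches) == 2
    # With half precision, x = 16 // 2 = 8, so the key cache is x-packed.
    assert key_caches[0].shape == (128, 8, 64 // 8, 16, 8)
    assert value_caches[0].shape == (128, 8, 64, 16)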
|
|