Support FP8-E5M2 KV Cache (#2279)
Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
@@ -6,14 +6,16 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from vllm._C import ops
+from vllm._C import ops, cache_ops
 from vllm.utils import get_max_shared_memory_bytes
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 12000  # Arbitrary values for testing
+# There may not be enough gpu memory due to large NUM_BLOCKS.
+# Reduce NUM_BLOCKS when it happens.
+NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
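The new KV_CACHE_DTYPE axis runs every parametrized case twice: with the cache kept in the model dtype ("auto") and with it quantized to FP8 E5M2 (1 sign bit, 5 exponent bits, 2 mantissa bits, one byte per element). A minimal sketch of the memory saving and the coarse rounding this implies, assuming a PyTorch build (>= 2.1) that exposes torch.float8_e5m2; the kernels under test use their own CUDA conversion path (cache_ops), not this cast:

import torch

# fp16 stores two bytes per cached element, E5M2 stores one.
print(torch.finfo(torch.float16).bits // 8)      # 2
print(torch.finfo(torch.float8_e5m2).bits // 8)  # 1

# E5M2 keeps the fp16 exponent range but only 2 mantissa bits,
# so a round trip rounds values noticeably.
x = torch.randn(4, dtype=torch.float16)
print(x)
print(x.to(torch.float8_e5m2).to(torch.float16))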
@@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("use_alibi", USE_ALIBI)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
@@ -116,6 +120,7 @@ def test_paged_attention(
     use_alibi: bool,
     block_size: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     seed: int,
     device: int,
 ) -> None:
@@ -158,8 +163,9 @@ def test_paged_attention(
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
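kv_cache_factory now receives the kv_cache_dtype string ahead of the torch dtype, so the caches it returns can be allocated in the quantized storage format. The factory itself is not shown in this diff; the helper below is a hypothetical sketch of how such a string flag might be resolved to a storage dtype, with "fp8_e5m2" represented as one raw byte per element (the actual encoding and decoding is handled by the CUDA cache ops):

import torch

def resolve_cache_dtype(kv_cache_dtype: str,
                        model_dtype: torch.dtype) -> torch.dtype:
    # "auto" keeps the KV cache in the model dtype.
    if kv_cache_dtype == "auto":
        return model_dtype
    # "fp8_e5m2" stores one byte per element; uint8 here stands in for
    # the raw E5M2 bit pattern that the CUDA kernels read and write.
    if kv_cache_dtype == "fp8_e5m2":
        return torch.uint8
    raise ValueError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")

# resolve_cache_dtype("fp8_e5m2", torch.float16) -> torch.uint8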
@@ -177,6 +183,7 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     elif version == "v2":
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
@@ -209,11 +216,30 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     else:
         raise AssertionError(f"Unknown version: {version}")
 
     # Run the reference implementation.
+    if kv_cache_dtype == "fp8_e5m2":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=gpu_id)
+        cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=gpu_id)
+        cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
         ref_output,
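Before the reference attention runs, the quantized caches are expanded back to the test dtype: the key cache keeps the packed layout (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x), where x = 16 // element_size groups 16 bytes of the head dimension into the innermost axis, and cache_ops.convert_fp8_e5m2 fills a freshly allocated full-precision buffer of the same shape. A purely illustrative equivalent of that up-conversion, assuming the quantized cache holds raw E5M2 bit patterns in a uint8 tensor and that torch.float8_e5m2 is available (the test relies on the CUDA op, not this):

import torch

def dequantize_fp8_e5m2_ref(cache_fp8: torch.Tensor,
                            out_dtype: torch.dtype) -> torch.Tensor:
    # Reinterpret the raw bytes as E5M2 values, then cast up.
    return cache_fp8.view(torch.float8_e5m2).to(out_dtype)

# Layout bookkeeping from the hunk above, for dtype=torch.float16:
x = 16 // torch.tensor([], dtype=torch.float16).element_size()  # -> 8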
@@ -230,7 +256,12 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8_e5m2":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
 
 
 def ref_multi_query_kv_attention(
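The looser atol for fp8_e5m2 reflects how coarse the format is: with only 2 mantissa bits, rounding a single value to E5M2 can shift it by roughly 10 % relative, and although the softmax-weighted averaging inside attention cancels much of that per-element noise, the 1e-3 absolute tolerance used for the full-precision cache would be too tight. A small sketch that measures the round-trip error empirically, again assuming torch.float8_e5m2 is available; the numbers are illustrative, not taken from the PR:

import torch

torch.manual_seed(0)
v = torch.randn(1024, dtype=torch.float16)
v_rt = v.to(torch.float8_e5m2).to(torch.float16)  # quantize, then dequantize

err = (v.float() - v_rt.float()).abs() / v.float().abs().clamp_min(1e-6)
print(f"max relative round-trip error: {err.max().item():.3f}")  # ~0.1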