Support FP8-E5M2 KV Cache (#2279)

Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
zhaoyang-star
2024-01-29 08:43:54 +08:00
committed by GitHub
parent 7d648418b8
commit 9090bf02e7
26 changed files with 912 additions and 196 deletions

View File

@@ -6,14 +6,16 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm._C import ops
from vllm._C import ops, cache_ops
from vllm.utils import get_max_shared_memory_bytes
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 12000 # Arbitrary values for testing
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
@@ -116,6 +120,7 @@ def test_paged_attention(
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: int,
) -> None:
@@ -158,8 +163,9 @@ def test_paged_attention(
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size, dtype,
seed, gpu_id)
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]
# Call the paged attention kernel.
@@ -177,6 +183,7 @@ def test_paged_attention(
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
@@ -209,11 +216,30 @@ def test_paged_attention(
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise AssertionError(f"Unknown version: {version}")
# Run the reference implementation.
if kv_cache_dtype == "fp8_e5m2":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=gpu_id)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=gpu_id)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
@@ -230,7 +256,12 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8_e5m2":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
def ref_multi_query_kv_attention(