[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)

This commit is contained in:
Cody Yu
2024-05-09 17:04:17 -07:00
committed by GitHub
parent 379da6dcb5
commit c833101740
17 changed files with 843 additions and 558 deletions

View File

@@ -5,8 +5,6 @@ import pytest
import torch
from vllm import _custom_ops as ops
from vllm._C import cache_ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,6 +23,8 @@ SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
@@ -124,8 +124,6 @@ def test_reshape_and_cache(
device: str,
kv_cache_dtype: str,
) -> None:
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@@ -149,9 +147,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(key_cache, cloned_key_cache)
ops.convert_fp8(cloned_key_cache, key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(value_cache, cloned_value_cache)
ops.convert_fp8(cloned_value_cache, value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
@@ -165,9 +163,9 @@ def test_reshape_and_cache(
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(key_cache, result_key_cache)
ops.convert_fp8(result_key_cache, key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(value_cache, result_value_cache)
ops.convert_fp8(result_value_cache, value_cache)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -255,8 +253,8 @@ def test_reshape_and_cache_flash(
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
@@ -299,8 +297,6 @@ def test_swap_blocks(
) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if not is_hip() and kv_cache_dtype == "fp8":
pytest.skip() # This test is not tuned for e5m2 cuda precision
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@@ -348,7 +344,6 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -357,7 +352,7 @@ def test_swap_blocks(
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
def test_fp8_e4m3_conversion(
num_heads: int,
head_size: int,
block_size: int,
@@ -377,9 +372,9 @@ def test_fp8_conversion(
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache, cache_fp8)
ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache)
ops.convert_fp8(cache_fp8, converted_cache)
ops.convert_fp8(converted_cache, cache_fp8)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)