Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: HaiShaw <hixiao@gmail.com> Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu> Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com> Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com> Co-authored-by: guofangze <guofangze@kuaishou.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -5,6 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm._C import cache_ops
|
||||
from vllm.utils import is_hip
|
||||
|
||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@@ -23,7 +24,7 @@ SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@@ -105,6 +106,7 @@ def test_copy_blocks(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache(
|
||||
kv_cache_factory,
|
||||
@@ -116,7 +118,10 @@ def test_reshape_and_cache(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
if not is_hip() and kv_cache_dtype == "fp8":
|
||||
pytest.skip() # This test is not tuned for e5m2 cuda precision
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
@@ -132,17 +137,33 @@ def test_reshape_and_cache(
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
None, seed, device)
|
||||
num_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(key_cache, cloned_key_cache)
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(value_cache, cloned_value_cache)
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Using default kv_scale
|
||||
kv_scale = 1.0
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping, "auto")
|
||||
slot_mapping, kv_cache_dtype, kv_scale)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(key_cache, result_key_cache)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(value_cache, result_value_cache)
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
@@ -156,8 +177,18 @@ def test_reshape_and_cache(
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
if kv_cache_dtype == "fp8":
|
||||
assert torch.allclose(result_key_cache,
|
||||
cloned_key_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
assert torch.allclose(result_value_cache,
|
||||
cloned_value_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
else:
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
|
||||
@@ -169,6 +200,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_swap_blocks(
|
||||
kv_cache_factory,
|
||||
@@ -181,7 +213,12 @@ def test_swap_blocks(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8" and "cpu" in direction:
|
||||
pytest.skip()
|
||||
if not is_hip() and kv_cache_dtype == "fp8":
|
||||
pytest.skip() # This test is not tuned for e5m2 cuda precision
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
@@ -202,13 +239,13 @@ def test_swap_blocks(
|
||||
|
||||
# Create the KV caches on the first device.
|
||||
src_key_caches, src_value_caches = kv_cache_factory(
|
||||
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
|
||||
src_device)
|
||||
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
|
||||
seed, src_device)
|
||||
|
||||
# Create the KV caches on the second device.
|
||||
dist_key_caches, dist_value_caches = kv_cache_factory(
|
||||
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
|
||||
dst_device)
|
||||
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
|
||||
seed, dst_device)
|
||||
|
||||
src_key_caches_clone = src_key_caches[0].clone()
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
@@ -223,3 +260,40 @@ def test_swap_blocks(
|
||||
dist_key_caches[0][dst].cpu())
|
||||
assert torch.allclose(src_value_caches_clone[src].cpu(),
|
||||
dist_value_caches[0][dst].cpu())
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_fp8_conversion(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
low = -224.0
|
||||
high = 224.0
|
||||
shape = (num_blocks, num_heads, head_size, block_size)
|
||||
cache = torch.empty(shape, dtype=dtype, device=device)
|
||||
cache.uniform_(low, high)
|
||||
|
||||
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
|
||||
cache_ops.convert_fp8(cache, cache_fp8)
|
||||
|
||||
converted_cache = torch.empty_like(cache)
|
||||
cache_ops.convert_fp8(cache_fp8, converted_cache)
|
||||
|
||||
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
|
||||
|
||||
Reference in New Issue
Block a user