Support FP8-E5M2 KV Cache (#2279)

Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
zhaoyang-star
2024-01-29 08:43:54 +08:00
committed by GitHub
parent 7d648418b8
commit 9090bf02e7
26 changed files with 912 additions and 196 deletions

View File

@@ -15,6 +15,7 @@ NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -26,6 +27,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
kv_cache_factory,
@@ -38,6 +40,7 @@ def test_copy_blocks(
dtype: torch.dtype,
seed: int,
device: int,
kv_cache_dtype: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
@@ -59,7 +62,8 @@ def test_copy_blocks(
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, dtype, seed, gpu_id)
head_size, kv_cache_dtype,
dtype, seed, gpu_id)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -124,7 +128,7 @@ def test_reshape_and_cache(
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype,
seed, gpu_id)
None, seed, gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.
@@ -133,7 +137,7 @@ def test_reshape_and_cache(
# Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping)
slot_mapping, "auto")
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)