Support FP8-E5M2 KV Cache (#2279)
Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
|
||||
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@@ -26,6 +27,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
@@ -38,6 +40,7 @@ def test_copy_blocks(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
+kv_cache_dtype: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
@@ -59,7 +62,8 @@ def test_copy_blocks(
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
-head_size, dtype, seed, gpu_id)
|
||||
+head_size, kv_cache_dtype,
|
||||
+dtype, seed, gpu_id)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
@@ -124,7 +128,7 @@ def test_reshape_and_cache(
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
-seed, gpu_id)
|
||||
+None, seed, gpu_id)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
@@ -133,7 +137,7 @@ def test_reshape_and_cache(
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
-slot_mapping)
|
||||
+slot_mapping, "auto")
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
|
||||
Reference in New Issue
Block a user