[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)
@@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness(
     if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
         pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
 
+    if (
+        backend_cls == FlashMLASparseBackend
+        and kv_cache_dtype.startswith("fp8")
+        and kv_cache_dtype != "fp8_ds_mla"
+    ):
+        pytest.skip(
+            "FlashMLA Sparse Attention backend fp8 only supports "
+            "fp8_ds_mla kv-cache dtype"
+        )
+
     supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
     if block_size not in supported_block_sizes:
         pytest.skip(
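
Note (illustration, not part of the diff): the two-stage dtype gating added above can be exercised standalone. Below is a minimal, hedged sketch in which _FakeSparseBackend, its dtype tuple, and maybe_skip_unsupported are hypothetical stand-ins; only the dtype strings and skip messages mirror the hunk.

import pytest

class _FakeSparseBackend:
    # Hypothetical stand-in for FlashMLASparseBackend; the real class
    # lives in vLLM and advertises its supported kv-cache dtypes.
    # The tuple below is illustrative, not the backend's actual list.
    supported_kv_cache_dtypes = ("auto", "fp8", "fp8_e4m3", "fp8_ds_mla")

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"

def maybe_skip_unsupported(backend_cls, kv_cache_dtype: str) -> None:
    # Mirrors the gating in the hunk: first the backend's general dtype
    # allowlist, then the FlashMLA-Sparse restriction that any fp8
    # kv-cache must use the DeepSeek MLA layout (fp8_ds_mla).
    if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
        pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
    if (
        backend_cls is _FakeSparseBackend  # stands in for FlashMLASparseBackend
        and kv_cache_dtype.startswith("fp8")
        and kv_cache_dtype != "fp8_ds_mla"
    ):
        pytest.skip(
            "FlashMLA Sparse Attention backend fp8 only supports "
            "fp8_ds_mla kv-cache dtype"
        )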
@@ -419,7 +429,7 @@ def test_sparse_backend_decode_correctness(
             num_blocks=vllm_config.cache_config.num_gpu_blocks,
             common_attn_metadata=common_attn_metadata,
             randomize_blocks=False,
-            kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
+            kv_cache_dtype=kv_cache_dtype,
             scale=kv_cache_scale,
         )
 
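Note (illustration, not part of the diff): with the dtype gating in place, the call above can forward kv_cache_dtype unconditionally instead of branching on a use_fp8_ds_mla_quantization flag, alongside a kv-cache scale. As a hedged sketch of why an fp8 kv-cache carries a scale at all, here is the generic scaled quantize/dequantize round trip; this is not vLLM's fp8_ds_mla layout, just the underlying pattern.

import torch

def quantize_kv_fp8(kv: torch.Tensor, scale: float) -> torch.Tensor:
    # Divide by the scale so values fit fp8 e4m3's dynamic range
    # before the cast narrows them to 8 bits.
    return (kv / scale).to(torch.float8_e4m3fn)

def dequantize_kv_fp8(kv_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Widen back to fp32 and undo the scaling when the cache is read.
    return kv_fp8.to(torch.float32) * scale

kv = torch.randn(4, 8)
scale = kv.abs().max().item() / 448.0  # 448.0 = max finite e4m3 value
roundtrip = dequantize_kv_fp8(quantize_kv_fp8(kv, scale), scale)
assert roundtrip.shape == kv.shape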