[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)

Wei Zhao
2026-03-07 16:51:54 -05:00
committed by GitHub
parent a6be75dbd2
commit 379689d533
8 changed files with 89 additions and 17 deletions
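
For context, a minimal usage sketch of the path this commit enables: serving an MLA model with a plain fp8 KV cache on the FlashInfer sparse MLA backend rather than FlashMLA's fp8_ds_mla format. This is an assumption-laden illustration, not part of the change: the backend override via the VLLM_ATTENTION_BACKEND environment variable and the model name are placeholders; the backend and dtype strings simply mirror those checked in the diff below.

import os

from vllm import LLM

# Assumption: force the backend whose get_name() == "FLASHINFER_MLA_SPARSE",
# matching the string checked in this diff.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA_SPARSE"

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # placeholder DeepSeek-style MLA model
    kv_cache_dtype="fp8",  # standard fp8 KV cache format, left as-is by this commit
)
print(llm.generate("Hello")[0].outputs[0].text)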

@@ -331,11 +331,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
        calculate_kv_scales = False
        self.quant_config = quant_config
        # Initialize KV cache quantization attributes
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)
        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
@@ -347,6 +342,36 @@ class MLAAttention(nn.Module, AttentionLayerBase):
            num_heads=self.num_heads,
        )
        # The FlashMLA sparse attention fp8 backend uses the "fp8_ds_mla" KV cache
        # format, so automatically convert any other fp8 KV cache format to it.
        if (
            self.attn_backend.get_name() == "FLASHMLA_SPARSE"
            and kv_cache_dtype.startswith("fp8")
            and kv_cache_dtype != "fp8_ds_mla"
        ):
            assert cache_config is not None
            cache_config.cache_dtype = "fp8_ds_mla"
            kv_cache_dtype = "fp8_ds_mla"
            logger.info_once(
                "Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
                "fp8 KV cache format, please set `--attention-backend "
                "FLASHINFER_MLA_SPARSE`"
            )
        if (
            self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
            and kv_cache_dtype.startswith("fp8")
        ):
            logger.info_once(
                "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
                "KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
            )
        # Initialize KV cache quantization attributes
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching