[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)

Author: Wei Zhao
Date: 2026-03-07 16:51:54 -05:00
Committed by: GitHub
Parent: a6be75dbd2
Commit: 379689d533
8 changed files with 89 additions and 17 deletions

@@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
        "fp8",
        "fp8_e4m3",
    ]

    @staticmethod
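
As background, this ClassVar is the list that kv-cache dtype validation consults when a particular --kv-cache-dtype is requested; a minimal sketch of such a check (the helper name below is illustrative, not vLLM's actual selection code):

def kv_cache_dtype_supported(backend_cls, kv_cache_dtype: str) -> bool:
    # Compare the requested cache dtype against the backend's advertised list.
    supported = getattr(backend_cls, "supported_kv_cache_dtypes", ["auto"])
    return kv_cache_dtype in supported

# e.g. kv_cache_dtype_supported(FlashInferMLASparseBackend, "fp8_e4m3") -> True
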
@@ -304,6 +306,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None

        # fp8 query quantization is required when using fp8 kv_cache,
        # as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
        # for query and kv_cache (mixed bf16+fp8 is not supported).
        self.supports_quant_query_input = True

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],

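Why this flag matters: with an fp8 kv-cache, the query also has to be quantized to fp8 before the TRTLLM-GEN sparse MLA kernel runs, since the kernel does not accept a bf16 query against an fp8 cache. A rough sketch of that quantization step, assuming a per-tensor scale (the helper name and scale handling here are illustrative, not the exact vLLM code):

import torch

def quantize_query_to_fp8(q: torch.Tensor, q_scale: float) -> torch.Tensor:
    # Scale the bf16/fp16 query into the representable range of
    # float8_e4m3fn and cast, so its dtype matches the fp8 kv-cache.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q_scaled = (q.float() / q_scale).clamp(finfo.min, finfo.max)
    return q_scaled.to(torch.float8_e4m3fn)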

@@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
        "auto",
        "bfloat16",
        "fp8_ds_mla",
        "fp8",  # alias for fp8_ds_mla
    ]

    @staticmethod
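
Because the impl below still asserts on the concrete "fp8_ds_mla" layout, the "fp8" alias is presumably resolved to "fp8_ds_mla" before the impl sees it; a sketch of that kind of normalization (the function name is an assumption, not the actual vLLM helper):

def resolve_kv_cache_dtype(requested: str) -> str:
    # Map the generic "fp8" alias onto the concrete DeepSeek-MLA fp8 cache
    # layout this backend implements; everything else passes through.
    if requested == "fp8":
        return "fp8_ds_mla"
    return requested
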
@@ -567,6 +568,12 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
        )
        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        if kv_cache_dtype.startswith("fp8"):
            assert kv_cache_dtype == "fp8_ds_mla", (
                "FlashMLA Sparse Attention backend fp8 only supports "
                "fp8_ds_mla kv-cache dtype"
            )

        if kv_cache_dtype == "fp8_ds_mla":
            # Reserve workspace during initialization
            vllm_config = get_current_vllm_config()
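
The workspace reservation referenced here allocates the scratch buffer once at init time, sized from the engine config, rather than on the forward path. A rough sketch under assumed sizing (the per-token byte count and buffer shape are placeholders, not the real kernel requirements):

import torch
from vllm.config import get_current_vllm_config

def reserve_fp8_ds_mla_workspace(device: torch.device) -> torch.Tensor:
    # Size the buffer from the scheduler's token budget so nothing is
    # allocated on the hot path; the sizing constant below is a placeholder.
    vllm_config = get_current_vllm_config()
    max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
    bytes_per_token = 1024  # placeholder
    return torch.empty(max_tokens * bytes_per_token, dtype=torch.uint8, device=device)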