[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)
@@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "bfloat16",
+        "fp8",
+        "fp8_e4m3",
     ]

     @staticmethod
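With these two entries added, `"fp8"` and `"fp8_e4m3"` KV caches are accepted by the FlashInfer MLA sparse backend. A minimal usage sketch, assuming a sparse-MLA model and vLLM's existing `kv_cache_dtype` engine argument (the model name and sampling settings below are illustrative, not part of this commit):

```python
# Illustrative only: the model name and options are assumptions, not from this diff.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # assumed: a model served by the sparse MLA path
    kv_cache_dtype="fp8",                   # now listed in supported_kv_cache_dtypes
    trust_remote_code=True,
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```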
@@ -304,6 +306,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None

+        # fp8 query quantization is required when using fp8 kv_cache,
+        # as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
+        # for query and kv_cache (mixed bf16+fp8 is not supported).
+        self.supports_quant_query_input = True
+
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
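The new flag only declares that the backend can consume an already-quantized query; the comment explains why it must: the TRTLLM-GEN sparse MLA kernel needs the query and the KV cache in the same dtype, and mixed bf16+fp8 is not supported. A minimal sketch of fp8 (e4m3) query quantization in PyTorch, assuming a per-tensor scale (an illustration, not vLLM's actual quantization path):

```python
# Minimal sketch, not the vLLM implementation: scale a bf16 query and cast it to
# fp8 e4m3 so its dtype matches an fp8 KV cache. The scale handling is illustrative.
import torch

def quantize_query_to_fp8(q: torch.Tensor, q_scale: float) -> torch.Tensor:
    """Scale and clamp a query tensor, then cast it to float8_e4m3fn."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    q_scaled = (q.float() / q_scale).clamp(finfo.min, finfo.max)
    return q_scaled.to(torch.float8_e4m3fn)

q_bf16 = torch.randn(2, 16, 576, dtype=torch.bfloat16)  # shapes are arbitrary
q_fp8 = quantize_query_to_fp8(q_bf16, q_scale=1.0)
assert q_fp8.dtype == torch.float8_e4m3fn  # matches an fp8_e4m3 KV cache
```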
@@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
         "auto",
         "bfloat16",
         "fp8_ds_mla",
+        "fp8",  # alias for fp8_ds_mla
     ]

     @staticmethod
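The `"fp8"` entry is only an alias; the concrete cache format for the FlashMLA sparse backend remains `fp8_ds_mla`. A hypothetical sketch of the mapping this implies (the helper name is made up for illustration and is not a vLLM function):

```python
# Hypothetical helper: shows the alias relationship, not vLLM's actual resolution code.
def resolve_flashmla_sparse_kv_cache_dtype(requested: str) -> str:
    aliases = {"fp8": "fp8_ds_mla"}  # plain fp8 is treated as fp8_ds_mla for this backend
    return aliases.get(requested, requested)

assert resolve_flashmla_sparse_kv_cache_dtype("fp8") == "fp8_ds_mla"
assert resolve_flashmla_sparse_kv_cache_dtype("bfloat16") == "bfloat16"
```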
@@ -567,6 +568,12 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
         )
         self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

+        if kv_cache_dtype.startswith("fp8"):
+            assert kv_cache_dtype == "fp8_ds_mla", (
+                "FlashMLA Sparse Attention backend fp8 only supports "
+                "fp8_ds_mla kv-cache dtype"
+            )
+
         if kv_cache_dtype == "fp8_ds_mla":
             # Reserve workspace during initialization
             vllm_config = get_current_vllm_config()
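This guard narrows the fp8 family down to the one format the FlashMLA sparse kernels support; presumably the `"fp8"` alias from the backend's dtype list has been resolved to `fp8_ds_mla` before the impl is constructed, otherwise the assertion would fire. Reproduced in isolation (illustrative, mirroring the added lines):

```python
# Illustrative standalone version of the guard added above.
def check_flashmla_sparse_fp8_dtype(kv_cache_dtype: str) -> None:
    if kv_cache_dtype.startswith("fp8"):
        assert kv_cache_dtype == "fp8_ds_mla", (
            "FlashMLA Sparse Attention backend fp8 only supports "
            "fp8_ds_mla kv-cache dtype"
        )

check_flashmla_sparse_fp8_dtype("fp8_ds_mla")  # accepted
check_flashmla_sparse_fp8_dtype("bfloat16")    # not fp8, so no check applies
# check_flashmla_sparse_fp8_dtype("fp8_e4m3")  # would raise AssertionError
```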