[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)

Wei Zhao
2026-03-07 16:51:54 -05:00
committed by GitHub
parent a6be75dbd2
commit 379689d533
8 changed files with 89 additions and 17 deletions


@@ -49,6 +49,11 @@ MLA_ATTENTION_FILE = (
 # Backends to skip during doc generation
 SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
+BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = {
+    # fp8 is an alias for fp8_ds_mla for FlashMLA Sparse
+    "FLASHMLA_SPARSE": {"fp8"},
+}
 
 def is_relevant_file(filepath: str) -> bool:
     """Check if a file matches any of the relevant patterns."""
@@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
         tree, impl_class_name, "can_return_lse_for_decode", False, file_path
     )
+    kv_cache_dtypes = parse_kv_cache_dtypes(class_node)
+    if backend_name in BACKEND_KV_DTYPE_EXCLUDES:
+        excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name]
+        kv_cache_dtypes = ", ".join(
+            d
+            for d in (d.strip() for d in kv_cache_dtypes.split(","))
+            if d not in excluded
+        )
     return {
         "name": backend_name,
         "dtypes": parse_supported_dtypes(class_node),
-        "kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
+        "kv_cache_dtypes": kv_cache_dtypes,
         "block_sizes": parse_block_sizes(class_node),
         "head_sizes": parse_head_sizes(class_node),
         "attn_types": parse_attention_types(class_node),