[Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)
This commit is contained in:
@@ -49,6 +49,11 @@ MLA_ATTENTION_FILE = (

# Backends to skip during doc generation
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}

BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = {
    # fp8 is an alias for fp8_ds_mla for FlashMLA Sparse
    "FLASHMLA_SPARSE": {"fp8"},
}


def is_relevant_file(filepath: str) -> bool:
    """Check if a file matches any of the relevant patterns."""
@@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
        tree, impl_class_name, "can_return_lse_for_decode", False, file_path
    )

    kv_cache_dtypes = parse_kv_cache_dtypes(class_node)
    if backend_name in BACKEND_KV_DTYPE_EXCLUDES:
        excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name]
        kv_cache_dtypes = ", ".join(
            d
            for d in (d.strip() for d in kv_cache_dtypes.split(","))
            if d not in excluded
        )

    return {
        "name": backend_name,
        "dtypes": parse_supported_dtypes(class_node),
-       "kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
+       "kv_cache_dtypes": kv_cache_dtypes,
        "block_sizes": parse_block_sizes(class_node),
        "head_sizes": parse_head_sizes(class_node),
        "attn_types": parse_attention_types(class_node),
Reference in New Issue
Block a user