[Perf] Set Flashinfer sparse MLA as default backend for FP8 kv cache (#37252)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
Wei Zhao
2026-03-17 16:09:20 -04:00
committed by GitHub
parent e78821b438
commit b36adfa349
3 changed files with 61 additions and 20 deletions

View File

@@ -1262,14 +1262,23 @@ When no backend is specified (the default):
"""
def _priority_table(title: str, backends: list[str]) -> list[str]:
def _priority_table(
title: str,
backends: list[str],
annotations: dict[str, str] | None = None,
) -> list[str]:
"""Generate a priority table for a list of backends."""
def _fmt(b: str) -> str:
suffix = annotations.get(b, "") if annotations else ""
return f"`{b}`{suffix}"
return [
f"**{title}:**",
"",
"| Priority | Backend |",
"| -------- | ------- |",
*[f"| {i} | `{b}` |" for i, b in enumerate(backends, 1)],
*[f"| {i} | {_fmt(b)} |" for i, b in enumerate(backends, 1)],
"",
]
@@ -1298,11 +1307,25 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
lines.extend(["### MLA Attention (DeepSeek-style)", ""])
mla_sm100_annotations = {
"FLASHINFER_MLA_SPARSE": "**\\***",
}
if "mla_sm100" in priorities:
lines.extend(_priority_table(sm100, priorities["mla_sm100"]))
lines.extend(
_priority_table(sm100, priorities["mla_sm100"], mla_sm100_annotations)
)
if "mla_default" in priorities:
lines.extend(_priority_table(ampere, priorities["mla_default"]))
if "mla_sm100" in priorities:
lines.append(
"> **\\*** For sparse MLA, FP8 KV cache always prefers "
"`FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` "
"is preferred for low query-head counts (<= 16), while "
"`FLASHMLA_SPARSE` is preferred otherwise."
)
lines.append(">")
lines.append(
"> **Note:** ROCm and CPU platforms have their own selection logic. "
"See the platform-specific documentation for details."