[Attention] FA4 integration (#32974)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2026-03-01 18:44:57 -05:00
committed by GitHub
parent 57a96e26c9
commit 8b5014d3dd
15 changed files with 818 additions and 55 deletions

View File

@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
# ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
# ---------------------------------------------------------------------------
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
def _parse_fa4_supported_caps() -> str | None:
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
"""
fa_interface_file = (
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
)
if not fa_interface_file.exists():
return None
try:
tree = ast.parse(fa_interface_file.read_text())
except Exception:
return None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
continue
for n in ast.walk(node):
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.NotIn)
and isinstance(n.comparators[0], ast.List)
):
continue
caps: list[int] = [
e.value
for e in n.comparators[0].elts
if isinstance(e, ast.Constant) and isinstance(e.value, int)
]
if caps:
caps.sort()
return f"{caps[0]}.x-{caps[-1]}.x"
return None
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if not FA_UTILS_FILE.exists():
@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8 = False
fa3_supports_sinks = False
fa3_compute_cap: str | None = None
fa4_compute_cap: str | None = None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef):
@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_sinks = True
break
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
# Check get_flash_attn_version for FA3/FA4 compute capability
if node.name == "get_flash_attn_version":
for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9`
# Handle IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp):
test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp):
for val in test.values:
if (
@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_compute_cap = f"{val.comparators[0].value}.x"
break
# Handle If statements for FA3/FA4 detection
# e.g. `if device_capability.major == 9` -> FA3
# `elif device_capability.major >= 10` -> FA4
if isinstance(n, ast.If):
test = n.test
comparisons = (
[v for v in test.values if isinstance(v, ast.Compare)]
if isinstance(test, ast.BoolOp)
else [test]
if isinstance(test, ast.Compare)
else []
)
for comp in comparisons:
if not (
isinstance(comp.left, ast.Attribute)
and comp.left.attr == "major"
and comp.comparators
and isinstance(comp.comparators[0], ast.Constant)
and isinstance(comp.comparators[0].value, int)
):
continue
op = comp.ops[0]
val = comp.comparators[0].value
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
fa3_compute_cap = f"{val}.x"
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
fa4_compute_cap = f"{val}.0"
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
if fa4_compute_cap is None:
fa4_compute_cap = _parse_fa4_supported_caps()
return {
"fa2": {
"supports_fp8": False,
@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks,
},
"fa4": {
"compute_capability": fa4_compute_cap,
"supports_fp8": False,
"supports_sink": False,
},
}
@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------
@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
all_backends: list[dict[str, Any]],
fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
"""Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
expanded = []
for backend in all_backends:
if backend["name"] != "FLASH_ATTN":
@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
expanded.append(fa2)
expanded.append(fa3)
# Create FA4 entry if FA4 features are available
if "fa4" in fa_features:
fa4 = backend.copy()
fa4["version"] = "FA4*"
fa4["_sort_key"] = "FLASH_ATTN"
fa4["_sort_order"] = 2
if fa_features["fa4"].get("compute_capability"):
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
expanded.append(fa4)
return expanded
@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
if fa_features:
footnotes.append(
"> **\\*** Specify the FlashAttention version via "
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, "
"`--attention-config.flash_attn_version=2`, `3`, or `4`. "
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
"FA2 otherwise."
)
if footnotes: