[Attention] FA4 integration (#32974)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
|
||||
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
|
||||
def _parse_fa4_supported_caps() -> str | None:
|
||||
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
|
||||
|
||||
Returns a dict with 'fa2' and 'fa3' keys containing their respective
|
||||
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
|
||||
"""
|
||||
fa_interface_file = (
|
||||
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
|
||||
)
|
||||
if not fa_interface_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
tree = ast.parse(fa_interface_file.read_text())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
|
||||
continue
|
||||
for n in ast.walk(node):
|
||||
if not (
|
||||
isinstance(n, ast.Compare)
|
||||
and len(n.ops) == 1
|
||||
and isinstance(n.ops[0], ast.NotIn)
|
||||
and isinstance(n.comparators[0], ast.List)
|
||||
):
|
||||
continue
|
||||
caps: list[int] = [
|
||||
e.value
|
||||
for e in n.comparators[0].elts
|
||||
if isinstance(e, ast.Constant) and isinstance(e.value, int)
|
||||
]
|
||||
if caps:
|
||||
caps.sort()
|
||||
return f"{caps[0]}.x-{caps[-1]}.x"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
"""Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
|
||||
|
||||
Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
|
||||
feature overrides for compute capability, KV cache dtypes, and sink support.
|
||||
"""
|
||||
if not FA_UTILS_FILE.exists():
|
||||
@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
fa3_supports_fp8 = False
|
||||
fa3_supports_sinks = False
|
||||
fa3_compute_cap: str | None = None
|
||||
fa4_compute_cap: str | None = None
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef):
|
||||
@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
fa3_supports_sinks = True
|
||||
break
|
||||
|
||||
# Check get_flash_attn_version for FA3 compute capability
|
||||
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
|
||||
# Check get_flash_attn_version for FA3/FA4 compute capability
|
||||
if node.name == "get_flash_attn_version":
|
||||
for n in ast.walk(node):
|
||||
# Look for IfExp (ternary) with `device_capability.major == 9`
|
||||
# Handle IfExp (ternary) with `device_capability.major == 9`
|
||||
if isinstance(n, ast.IfExp):
|
||||
test = n.test
|
||||
# Check if test is a BoolOp (and) containing the major check
|
||||
if isinstance(test, ast.BoolOp):
|
||||
for val in test.values:
|
||||
if (
|
||||
@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
fa3_compute_cap = f"{val.comparators[0].value}.x"
|
||||
break
|
||||
|
||||
# Handle If statements for FA3/FA4 detection
|
||||
# e.g. `if device_capability.major == 9` -> FA3
|
||||
# `elif device_capability.major >= 10` -> FA4
|
||||
if isinstance(n, ast.If):
|
||||
test = n.test
|
||||
comparisons = (
|
||||
[v for v in test.values if isinstance(v, ast.Compare)]
|
||||
if isinstance(test, ast.BoolOp)
|
||||
else [test]
|
||||
if isinstance(test, ast.Compare)
|
||||
else []
|
||||
)
|
||||
for comp in comparisons:
|
||||
if not (
|
||||
isinstance(comp.left, ast.Attribute)
|
||||
and comp.left.attr == "major"
|
||||
and comp.comparators
|
||||
and isinstance(comp.comparators[0], ast.Constant)
|
||||
and isinstance(comp.comparators[0].value, int)
|
||||
):
|
||||
continue
|
||||
op = comp.ops[0]
|
||||
val = comp.comparators[0].value
|
||||
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
|
||||
fa3_compute_cap = f"{val}.x"
|
||||
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
|
||||
fa4_compute_cap = f"≥{val}.0"
|
||||
|
||||
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
|
||||
if fa4_compute_cap is None:
|
||||
fa4_compute_cap = _parse_fa4_supported_caps()
|
||||
|
||||
return {
|
||||
"fa2": {
|
||||
"supports_fp8": False,
|
||||
@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
"supports_fp8": fa3_supports_fp8,
|
||||
"supports_sink": fa3_supports_sinks,
|
||||
},
|
||||
"fa4": {
|
||||
"compute_capability": fa4_compute_cap,
|
||||
"supports_fp8": False,
|
||||
"supports_sink": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
|
||||
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
|
||||
all_backends: list[dict[str, Any]],
|
||||
fa_features: dict[str, dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
|
||||
"""Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
|
||||
expanded = []
|
||||
for backend in all_backends:
|
||||
if backend["name"] != "FLASH_ATTN":
|
||||
@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
|
||||
|
||||
expanded.append(fa2)
|
||||
expanded.append(fa3)
|
||||
|
||||
# Create FA4 entry if FA4 features are available
|
||||
if "fa4" in fa_features:
|
||||
fa4 = backend.copy()
|
||||
fa4["version"] = "FA4*"
|
||||
fa4["_sort_key"] = "FLASH_ATTN"
|
||||
fa4["_sort_order"] = 2
|
||||
if fa_features["fa4"].get("compute_capability"):
|
||||
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
|
||||
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
|
||||
expanded.append(fa4)
|
||||
|
||||
return expanded
|
||||
|
||||
|
||||
@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
|
||||
if fa_features:
|
||||
footnotes.append(
|
||||
"> **\\*** Specify the FlashAttention version via "
|
||||
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, "
|
||||
"`--attention-config.flash_attn_version=2`, `3`, or `4`. "
|
||||
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
|
||||
"FA2 otherwise."
|
||||
)
|
||||
if footnotes:
|
||||
|
||||
Reference in New Issue
Block a user