2026-01-28 17:20:22 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
"""
|
|
|
|
|
Generates documentation table for attention backends showing feature support.
|
|
|
|
|
|
|
|
|
|
This script parses all registered attention backends using AST (no imports needed)
|
|
|
|
|
and generates a markdown table showing what features each backend supports,
|
|
|
|
|
based on the checks in AttentionBackend.validate_configuration().
|
|
|
|
|
|
|
|
|
|
This approach avoids requiring CUDA/ROCm/GPU libraries to be installed.
|
|
|
|
|
|
|
|
|
|
When used as a pre-commit hook, this script receives filenames as arguments
|
|
|
|
|
and only runs the check if any of the relevant files were modified.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import ast
|
|
|
|
|
import fnmatch
|
|
|
|
|
import sys
|
2026-02-09 18:33:43 -05:00
|
|
|
from collections.abc import Callable
|
2026-01-28 17:20:22 -05:00
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Constants and file paths
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
2026-01-28 17:20:22 -05:00
|
|
|
# Repository root: this file lives at tools/pre_commit/<script>.py, so the
# root is three levels up from __file__.
REPO_ROOT = Path(__file__).parent.parent.parent

# Glob patterns (relative to REPO_ROOT) whose modification makes this check
# relevant when running as a pre-commit hook.
RELEVANT_PATTERNS = [
    "vllm/v1/attention/backends/*.py",
    "vllm/v1/attention/backends/**/*.py",
    "vllm/v1/attention/backends/fa_utils.py",
    "vllm/model_executor/layers/attention/mla_attention.py",
    "vllm/platforms/cuda.py",
    "tools/pre_commit/generate_attention_backend_docs.py",
    "docs/design/attention_backends.md",
]

# Concrete source files this script parses (via AST, never imported).
BACKENDS_DIR = REPO_ROOT / "vllm" / "v1" / "attention" / "backends"
REGISTRY_FILE = BACKENDS_DIR / "registry.py"
CUDA_PLATFORM_FILE = REPO_ROOT / "vllm" / "platforms" / "cuda.py"
FA_UTILS_FILE = BACKENDS_DIR / "fa_utils.py"
FLASHINFER_UTILS_FILE = REPO_ROOT / "vllm" / "utils" / "flashinfer.py"
MLA_ATTENTION_FILE = (
    REPO_ROOT / "vllm" / "model_executor" / "layers" / "attention" / "mla_attention.py"
)

# Backends to skip during doc generation
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
|
|
|
|
|
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
def is_relevant_file(filepath: str) -> bool:
    """Return True when *filepath* matches one of RELEVANT_PATTERNS.

    Absolute paths are first rebased onto REPO_ROOT; anything outside the
    repository tree is never relevant.
    """
    candidate = Path(filepath)
    if candidate.is_absolute():
        try:
            candidate = candidate.relative_to(REPO_ROOT)
        except ValueError:
            # Absolute path that does not live under the repository.
            return False
    as_text = str(candidate)

    for pattern in RELEVANT_PATTERNS:
        if fnmatch.fnmatch(as_text, pattern):
            return True
    return False
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# AST utility helpers
|
|
|
|
|
# ---------------------------------------------------------------------------
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def find_class_in_ast(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
|
|
|
|
|
"""Find a class definition in an AST."""
|
2026-01-28 17:20:22 -05:00
|
|
|
for node in ast.walk(tree):
|
2026-02-09 18:33:43 -05:00
|
|
|
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
|
|
|
|
return node
|
|
|
|
|
return None
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def find_method(node: ast.ClassDef, method_name: str) -> ast.FunctionDef | None:
|
|
|
|
|
"""Find a method in a class definition."""
|
2026-01-28 17:20:22 -05:00
|
|
|
for item in node.body:
|
2026-02-09 18:33:43 -05:00
|
|
|
if isinstance(item, ast.FunctionDef) and item.name == method_name:
|
|
|
|
|
return item
|
|
|
|
|
return None
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def method_returns_true(method: ast.FunctionDef | None) -> bool:
|
|
|
|
|
"""Check if a method simply returns True."""
|
|
|
|
|
if method is None:
|
|
|
|
|
return False
|
|
|
|
|
for node in ast.walk(method):
|
|
|
|
|
if (
|
|
|
|
|
isinstance(node, ast.Return)
|
|
|
|
|
and isinstance(node.value, ast.Constant)
|
|
|
|
|
and node.value.value is True
|
|
|
|
|
):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def check_method_overrides(node: ast.ClassDef, method_name: str) -> bool:
    """Return True when *method_name* is overridden on *node* and returns True."""
    override = find_method(node, method_name)
    return method_returns_true(override)
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _find_bool_class_var(class_node: ast.ClassDef, var_name: str) -> bool | None:
|
|
|
|
|
"""Find a bool class variable in a class definition. Returns None if not found."""
|
|
|
|
|
for item in class_node.body:
|
|
|
|
|
# Check for annotated assignment: attr: bool = True/False
|
|
|
|
|
if (
|
|
|
|
|
isinstance(item, ast.AnnAssign)
|
|
|
|
|
and isinstance(item.target, ast.Name)
|
|
|
|
|
and item.target.id == var_name
|
|
|
|
|
and isinstance(item.value, ast.Constant)
|
|
|
|
|
and isinstance(item.value.value, bool)
|
|
|
|
|
):
|
|
|
|
|
return item.value.value
|
|
|
|
|
# Check for plain assignment: attr = True/False
|
|
|
|
|
if isinstance(item, ast.Assign):
|
|
|
|
|
for target in item.targets:
|
2026-01-28 17:20:22 -05:00
|
|
|
if (
|
2026-02-09 18:33:43 -05:00
|
|
|
isinstance(target, ast.Name)
|
|
|
|
|
and target.id == var_name
|
|
|
|
|
and isinstance(item.value, ast.Constant)
|
|
|
|
|
and isinstance(item.value.value, bool)
|
2026-01-28 17:20:22 -05:00
|
|
|
):
|
2026-02-09 18:33:43 -05:00
|
|
|
return item.value.value
|
|
|
|
|
return None
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _parse_list_class_var(node: ast.ClassDef, var_name: str) -> list[str] | None:
|
|
|
|
|
"""Parse a list-type class variable, returning None if not found."""
|
|
|
|
|
for item in node.body:
|
|
|
|
|
if not isinstance(item, ast.AnnAssign):
|
2026-01-28 17:20:22 -05:00
|
|
|
continue
|
2026-02-09 18:33:43 -05:00
|
|
|
if not isinstance(item.target, ast.Name):
|
|
|
|
|
continue
|
|
|
|
|
if item.target.id != var_name:
|
|
|
|
|
continue
|
|
|
|
|
if not (item.value and isinstance(item.value, ast.List)):
|
|
|
|
|
continue
|
|
|
|
|
result = []
|
|
|
|
|
for elt in item.value.elts:
|
|
|
|
|
if isinstance(elt, ast.Attribute):
|
|
|
|
|
result.append(elt.attr)
|
|
|
|
|
elif isinstance(elt, ast.Constant):
|
|
|
|
|
result.append(str(elt.value))
|
|
|
|
|
return result
|
|
|
|
|
return None
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _parse_return_list(
|
|
|
|
|
method: ast.FunctionDef | None, handle_multiple_of: bool = False
|
|
|
|
|
) -> list[str]:
|
|
|
|
|
"""Extract list items from a method's return statement."""
|
|
|
|
|
if method is None:
|
|
|
|
|
return []
|
|
|
|
|
for stmt in ast.walk(method):
|
|
|
|
|
if not isinstance(stmt, ast.Return):
|
|
|
|
|
continue
|
|
|
|
|
if not isinstance(stmt.value, ast.List):
|
|
|
|
|
continue
|
|
|
|
|
sizes = []
|
|
|
|
|
for elt in stmt.value.elts:
|
|
|
|
|
if isinstance(elt, ast.Constant):
|
|
|
|
|
sizes.append(str(elt.value))
|
|
|
|
|
elif (
|
|
|
|
|
handle_multiple_of
|
|
|
|
|
and isinstance(elt, ast.Call)
|
|
|
|
|
and isinstance(elt.func, ast.Name)
|
|
|
|
|
and elt.func.id == "MultipleOf"
|
|
|
|
|
and elt.args
|
|
|
|
|
and isinstance(elt.args[0], ast.Constant)
|
|
|
|
|
):
|
|
|
|
|
sizes.append(f"%{elt.args[0].value}")
|
|
|
|
|
if sizes:
|
|
|
|
|
return sizes
|
|
|
|
|
return []
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _get_parent_class_name(class_node: ast.ClassDef) -> str | None:
|
|
|
|
|
"""Get the first parent class name (simple name only).
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
Handles both simple inheritance (class Foo(Bar)) and generic
|
|
|
|
|
inheritance (class Foo(Bar[T])).
|
2026-01-28 17:20:22 -05:00
|
|
|
"""
|
2026-02-09 18:33:43 -05:00
|
|
|
if not class_node.bases:
|
|
|
|
|
return None
|
|
|
|
|
base = class_node.bases[0]
|
|
|
|
|
if isinstance(base, ast.Name):
|
|
|
|
|
return base.id
|
|
|
|
|
if isinstance(base, ast.Subscript) and isinstance(base.value, ast.Name):
|
|
|
|
|
return base.value.id
|
|
|
|
|
return None
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _resolve_import_to_file(
    tree: ast.AST, class_name: str, source_file: Path | None = None
) -> Path | None:
    """Try to resolve a class name to its source file via imports in the AST.

    Handles both absolute imports (from vllm.foo import Bar) and relative
    imports (from .foo import Bar) when source_file is provided.

    Args:
        tree: Parsed module AST to scan for ``from ... import ...`` statements.
        class_name: Name the class is bound to in *tree* (aliases respected).
        source_file: Path of the file *tree* was parsed from; required to
            resolve relative imports.

    Returns:
        Path to the imported module's .py file if it exists on disk, else None.
    """
    for node in ast.walk(tree):
        if not isinstance(node, ast.ImportFrom):
            continue
        for alias in node.names:
            # `asname` wins when present: `from x import Y as Z` binds Z.
            actual_name = alias.asname or alias.name
            if actual_name != class_name:
                continue
            if not node.module:
                # `from . import x` style (bare relative) is not handled.
                continue

            if node.level and node.level > 0 and source_file:
                # Relative import: resolve from the source file's directory.
                # level 1 = same package dir; each extra level goes one up.
                base_dir = source_file.parent
                for _ in range(node.level - 1):
                    base_dir = base_dir.parent
                module_path = node.module.replace(".", "/")
                py_file = base_dir / f"{module_path}.py"
            else:
                # Absolute import: resolve from the repository root.
                module_path = node.module.replace(".", "/")
                py_file = REPO_ROOT / f"{module_path}.py"

            if py_file.exists():
                return py_file
    return None
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None:
|
|
|
|
|
"""Find a compute capability from is_device_capability_family() calls in a function.
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
Looks for the pattern: current_platform.is_device_capability_family(N)
|
|
|
|
|
and converts N (e.g. 100) to a CC string (e.g. "10.x").
|
|
|
|
|
"""
|
2026-01-28 17:20:22 -05:00
|
|
|
for node in ast.walk(tree):
|
2026-02-09 18:33:43 -05:00
|
|
|
if not isinstance(node, ast.FunctionDef) or node.name != func_name:
|
|
|
|
|
continue
|
|
|
|
|
for n in ast.walk(node):
|
|
|
|
|
if (
|
|
|
|
|
isinstance(n, ast.Call)
|
|
|
|
|
and isinstance(n.func, ast.Attribute)
|
|
|
|
|
and n.func.attr == "is_device_capability_family"
|
|
|
|
|
and n.args
|
|
|
|
|
and isinstance(n.args[0], ast.Constant)
|
|
|
|
|
and isinstance(n.args[0].value, int)
|
|
|
|
|
):
|
|
|
|
|
return f"{n.args[0].value // 10}.x"
|
2026-01-28 17:20:22 -05:00
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Registry and file resolution
|
|
|
|
|
# ---------------------------------------------------------------------------
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def parse_registry() -> dict[str, str]:
    """Parse the registry.py file to get backend names and their class paths.

    Returns:
        Mapping of AttentionBackendEnum member name -> class path string,
        or an empty dict when the enum class is not found in the registry.
    """
    tree = ast.parse(REGISTRY_FILE.read_text())
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node.name == "AttentionBackendEnum":
            return _extract_enum_values(node)
    return {}
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _extract_enum_values(node: ast.ClassDef) -> dict[str, str]:
    """Collect ``NAME = "value"`` pairs from an enum-style class body.

    Members whose value is not a truthy constant (e.g. ``None``) are skipped.
    """
    mapping: dict[str, str] = {}
    for stmt in node.body:
        if not isinstance(stmt, ast.Assign):
            continue
        value = stmt.value
        if not (isinstance(value, ast.Constant) and value.value):
            continue
        for target in stmt.targets:
            if isinstance(target, ast.Name):
                mapping[target.id] = value.value
    return mapping
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_file_from_class_path(class_path: str) -> Path | None:
    """Map a dotted ``pkg.module.Class`` path to the module's .py file.

    Returns None for an empty path or when the file does not exist on disk.
    """
    if not class_path:
        return None
    module = class_path.rsplit(".", 1)[0]
    candidate = REPO_ROOT / f"{module.replace('.', '/')}.py"
    return candidate if candidate.exists() else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Backend feature extraction from AST
|
|
|
|
|
# ---------------------------------------------------------------------------
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_supported_dtypes(node: ast.ClassDef) -> str:
    """Render the backend's supported_dtypes as a short comma-separated string."""
    abbrev = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}
    declared = _parse_list_class_var(node, "supported_dtypes")
    if declared is None:
        # Backend does not declare its own list; report the common default.
        return "fp16, bf16"
    return ", ".join(abbrev.get(name, name) for name in declared)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_kv_cache_dtypes(node: ast.ClassDef) -> str:
    """Parse supported_kv_cache_dtypes class var or supports_kv_cache_dtype method.

    Returns a comma-separated dtype string, falling back to "auto" when
    neither source yields any names.
    """
    # First try the class variable
    dtypes = _parse_list_class_var(node, "supported_kv_cache_dtypes")
    if dtypes:
        return ", ".join(dtypes)

    # Fall back to parsing the supports_kv_cache_dtype method
    # Look for `kv_cache_dtype in ["auto", "bfloat16"]` pattern
    method = find_method(node, "supports_kv_cache_dtype")
    if method:
        for n in ast.walk(method):
            # Match a single-op membership test against a list literal.
            if (
                isinstance(n, ast.Compare)
                and len(n.ops) == 1
                and isinstance(n.ops[0], ast.In)
                and len(n.comparators) == 1
                and isinstance(n.comparators[0], ast.List)
            ):
                # Only string constants count as dtype names.
                dtypes = [
                    e.value
                    for e in n.comparators[0].elts
                    if isinstance(e, ast.Constant) and isinstance(e.value, str)
                ]
                if dtypes:
                    return ", ".join(dtypes)

    return "auto"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_block_sizes(node: ast.ClassDef) -> str:
    """Render supported kernel block sizes, defaulting to "Any"."""
    parsed = _parse_return_list(
        find_method(node, "get_supported_kernel_block_sizes"),
        handle_multiple_of=True,
    )
    if not parsed:
        return "Any"
    return ", ".join(parsed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_head_sizes(node: ast.ClassDef) -> str:
    """Render supported head sizes, defaulting to "Any"."""
    parsed = _parse_return_list(find_method(node, "get_supported_head_sizes"))
    if not parsed:
        return "Any"
    return ", ".join(parsed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_compute_capability(node: ast.ClassDef) -> str:
    """Parse supports_compute_capability method.

    Harvests comparisons from the method body and renders a human-readable
    CC range string: exact-major checks win over DeviceCapability bounds;
    "Any" when no recognizable pattern is found.
    """
    method = find_method(node, "supports_compute_capability")
    if method is None:
        return "Any"

    min_cap: tuple[int, int] | None = None
    max_cap: tuple[int, int] | None = None
    major_list: list[int] = []

    for n in ast.walk(method):
        if not isinstance(n, ast.Compare):
            continue

        # Handle `capability >= DeviceCapability(...)` or `capability <= ...`
        for op, comp in zip(n.ops, n.comparators):
            if not (
                isinstance(comp, ast.Call)
                and isinstance(comp.func, ast.Name)
                and comp.func.id == "DeviceCapability"
                and comp.args
                and isinstance(comp.args[0], ast.Constant)
            ):
                continue
            major = comp.args[0].value
            minor = 0
            # Optional second positional arg is the minor version.
            if len(comp.args) > 1 and isinstance(comp.args[1], ast.Constant):
                minor = comp.args[1].value
            if isinstance(op, ast.GtE):
                min_cap = (major, minor)
            elif isinstance(op, ast.LtE):
                max_cap = (major, minor)

        # Handle `capability.major == N` or `capability.major in [N, M]`
        if (
            isinstance(n.left, ast.Attribute)
            and n.left.attr == "major"
            and len(n.ops) == 1
            and len(n.comparators) == 1
        ):
            comp = n.comparators[0]
            if isinstance(n.ops[0], ast.Eq) and isinstance(comp, ast.Constant):
                major_list.append(comp.value)
            elif isinstance(n.ops[0], ast.In) and isinstance(comp, ast.List):
                major_list.extend(
                    e.value
                    for e in comp.elts
                    if isinstance(e, ast.Constant) and isinstance(e.value, int)
                )

    # Exact-major checks take precedence over DeviceCapability bounds.
    if major_list:
        major_list.sort()
        if len(major_list) == 1:
            return f"{major_list[0]}.x"
        return f"{major_list[0]}.x-{major_list[-1]}.x"

    if min_cap:
        if max_cap:
            return f"{min_cap[0]}.x-{max_cap[0]}.x"
        return f"≥{min_cap[0]}.{min_cap[1]}"

    return "Any"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_attention_types(node: ast.ClassDef) -> str:
    """Parse supports_attn_type method.

    Collects AttentionType members from membership tests in the method and
    renders a display string; "Decoder" is the default when the method is
    absent or no pattern matches.
    """
    method = find_method(node, "supports_attn_type")
    if method is None:
        return "Decoder"

    # Display names for the recognized AttentionType members.
    type_map = {
        "DECODER": "Decoder",
        "ENCODER": "Encoder",
        "ENCODER_ONLY": "Encoder Only",
        "ENCODER_DECODER": "Enc-Dec",
    }
    types: set[str] = set()

    for n in ast.walk(method):
        # Handle `attn_type in (AttentionType.DECODER, ...)`
        if not (
            isinstance(n, ast.Compare)
            and len(n.ops) == 1
            and isinstance(n.ops[0], ast.In)
            and len(n.comparators) == 1
            and isinstance(n.comparators[0], ast.Tuple | ast.Set)
        ):
            continue

        for elt in n.comparators[0].elts:
            if isinstance(elt, ast.Attribute) and elt.attr in type_map:
                types.add(type_map[elt.attr])

    if not types:
        return "Decoder"
    # Three or more of the four known types is summarized as "All".
    return "All" if len(types) >= 3 else ", ".join(sorted(types))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_impl_bool_attr(
    tree: ast.AST,
    class_name: str,
    attr_name: str,
    default: bool = False,
    source_file: Path | None = None,
    _visited: set[str] | None = None,
) -> bool:
    """Parse a boolean class attribute from an impl class, following inheritance.

    Walks up the inheritance chain within the same file and across files
    (by resolving imports) to find the attribute value.

    Args:
        tree: AST of the module currently being inspected.
        class_name: Class to look up in *tree*.
        attr_name: Boolean class attribute to resolve.
        default: Value returned when the attribute is never found.
        source_file: Path of *tree*'s file, used for relative-import resolution.
        _visited: Internal recursion guard; class names already inspected.
    """
    if _visited is None:
        _visited = set()
    # Guard against inheritance cycles / repeated lookups.
    if class_name in _visited:
        return default
    _visited.add(class_name)

    class_node = find_class_in_ast(tree, class_name)
    if class_node is None:
        return default

    # Check directly on this class
    value = _find_bool_class_var(class_node, attr_name)
    if value is not None:
        return value

    # Check parent class
    parent_name = _get_parent_class_name(class_node)
    if parent_name:
        # Try parent in same file first
        parent_node = find_class_in_ast(tree, parent_name)
        if parent_node is not None:
            return parse_impl_bool_attr(
                tree, parent_name, attr_name, default, source_file, _visited
            )

        # Try resolving cross-file import
        parent_file = _resolve_import_to_file(tree, parent_name, source_file)
        if parent_file:
            try:
                parent_tree = ast.parse(parent_file.read_text())
                return parse_impl_bool_attr(
                    parent_tree,
                    parent_name,
                    attr_name,
                    default,
                    parent_file,
                    _visited,
                )
            except Exception:
                # Best effort: unreadable/unparseable parent falls through
                # to the default rather than aborting doc generation.
                pass

    return default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None:
    """Analyze a backend class and extract feature information.

    Args:
        backend_name: Enum name of the backend (e.g. "FLASH_ATTN").
        class_path: Dotted path to the backend class in the repo.

    Returns:
        A feature dict for the docs table, or None when the file/class
        cannot be located or parsed.
    """
    file_path = get_file_from_class_path(class_path)
    if file_path is None:
        return None

    try:
        tree = ast.parse(file_path.read_text())
    except Exception as e:
        print(f" Warning: Could not parse {file_path}: {e}", file=sys.stderr)
        return None

    class_name = class_path.rsplit(".", 1)[1]
    class_node = find_class_in_ast(tree, class_name)
    if class_node is None:
        return None

    # Check if this is an MLA backend by parent class or naming
    parent = _get_parent_class_name(class_node)
    mla_parents = {"MLACommonBackend", "FlashMLABackend", "FlashMLASparseBackend"}
    is_mla_backend = (
        parent in mla_parents
        or ".mla." in class_path.lower()
        or "_mla" in backend_name.lower()
    )

    # Determine compute capability - use N/A for non-CUDA backends
    is_non_cuda = backend_name.startswith(("CPU_", "ROCM_"))
    compute_cap = "N/A" if is_non_cuda else parse_compute_capability(class_node)

    # Parse impl class features (DCP support)
    # Find the impl class name returned by get_impl_cls(), if any.
    impl_method = find_method(class_node, "get_impl_cls")
    impl_class_name = None
    if impl_method:
        for stmt in ast.walk(impl_method):
            if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name):
                impl_class_name = stmt.value.id
                break

    supports_dcp = False
    if impl_class_name:
        # DCP support is signalled by can_return_lse_for_decode on the impl
        # class (possibly inherited, possibly defined in another file).
        supports_dcp = parse_impl_bool_attr(
            tree, impl_class_name, "can_return_lse_for_decode", False, file_path
        )

    return {
        "name": backend_name,
        "dtypes": parse_supported_dtypes(class_node),
        "kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
        "block_sizes": parse_block_sizes(class_node),
        "head_sizes": parse_head_sizes(class_node),
        "attn_types": parse_attention_types(class_node),
        "compute_capability": compute_cap,
        "is_mla": is_mla_backend or check_method_overrides(class_node, "is_mla"),
        "supports_sink": check_method_overrides(class_node, "supports_sink"),
        "is_sparse": check_method_overrides(class_node, "is_sparse"),
        "supports_mm_prefix": check_method_overrides(class_node, "supports_mm_prefix"),
        "supports_dcp": supports_dcp,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
2026-03-01 18:44:57 -05:00
|
|
|
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
2026-03-01 18:44:57 -05:00
|
|
|
def _parse_fa4_supported_caps() -> str | None:
    """Parse flash_attn_interface.py for FA4 supported compute capabilities.

    Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().

    Returns:
        A CC range string like "9.x-11.x", or None when the file, function,
        or pattern is not found.
    """
    fa_interface_file = (
        REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
    )
    if not fa_interface_file.exists():
        return None

    try:
        tree = ast.parse(fa_interface_file.read_text())
    except Exception:
        # Best effort: an unparseable file simply yields no FA4 info.
        return None

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
            continue
        for n in ast.walk(node):
            # Match the `cc not in [...]` guard inside the function.
            if not (
                isinstance(n, ast.Compare)
                and len(n.ops) == 1
                and isinstance(n.ops[0], ast.NotIn)
                and isinstance(n.comparators[0], ast.List)
            ):
                continue
            caps: list[int] = [
                e.value
                for e in n.comparators[0].elts
                if isinstance(e, ast.Constant) and isinstance(e.value, int)
            ]
            if caps:
                caps.sort()
                return f"{caps[0]}.x-{caps[-1]}.x"

    return None
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
    """Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.

    Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
    feature overrides for compute capability, KV cache dtypes, and sink support.
    Returns {} when fa_utils.py is missing or unparseable.
    """
    if not FA_UTILS_FILE.exists():
        return {}

    try:
        tree = ast.parse(FA_UTILS_FILE.read_text())
    except Exception:
        return {}

    # Analyze the functions to determine FA3-specific features
    fa3_supports_fp8 = False
    fa3_supports_sinks = False
    fa3_compute_cap: str | None = None
    fa4_compute_cap: str | None = None

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue

        # Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
        if node.name == "flash_attn_supports_fp8":
            for n in ast.walk(node):
                if (
                    isinstance(n, ast.Compare)
                    and isinstance(n.left, ast.Call)
                    and isinstance(n.left.func, ast.Name)
                    and n.left.func.id == "get_flash_attn_version"
                ):
                    fa3_supports_fp8 = True
                    break

        # Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
        if node.name == "flash_attn_supports_sinks":
            for n in ast.walk(node):
                if (
                    isinstance(n, ast.Compare)
                    and isinstance(n.left, ast.Call)
                    and isinstance(n.left.func, ast.Name)
                    and n.left.func.id == "get_flash_attn_version"
                ):
                    fa3_supports_sinks = True
                    break

        # Check get_flash_attn_version for FA3/FA4 compute capability
        if node.name == "get_flash_attn_version":
            for n in ast.walk(node):
                # Handle IfExp (ternary) with `device_capability.major == 9`
                if isinstance(n, ast.IfExp):
                    test = n.test
                    if isinstance(test, ast.BoolOp):
                        for val in test.values:
                            if (
                                isinstance(val, ast.Compare)
                                and isinstance(val.left, ast.Attribute)
                                and val.left.attr == "major"
                                and val.comparators
                                and isinstance(val.comparators[0], ast.Constant)
                            ):
                                fa3_compute_cap = f"{val.comparators[0].value}.x"
                                break

                # Handle If statements for FA3/FA4 detection
                # e.g. `if device_capability.major == 9` -> FA3
                # `elif device_capability.major >= 10` -> FA4
                if isinstance(n, ast.If):
                    test = n.test
                    # Normalize to a flat list of Compare nodes to inspect.
                    comparisons = (
                        [v for v in test.values if isinstance(v, ast.Compare)]
                        if isinstance(test, ast.BoolOp)
                        else [test]
                        if isinstance(test, ast.Compare)
                        else []
                    )
                    for comp in comparisons:
                        if not (
                            isinstance(comp.left, ast.Attribute)
                            and comp.left.attr == "major"
                            and comp.comparators
                            and isinstance(comp.comparators[0], ast.Constant)
                            and isinstance(comp.comparators[0].value, int)
                        ):
                            continue
                        op = comp.ops[0]
                        val = comp.comparators[0].value
                        # First match wins for each of FA3 (==) and FA4 (>=).
                        if isinstance(op, ast.Eq) and fa3_compute_cap is None:
                            fa3_compute_cap = f"{val}.x"
                        elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
                            fa4_compute_cap = f"≥{val}.0"

    # Fallback: try to parse FA4 compute caps from flash_attn_interface.py
    if fa4_compute_cap is None:
        fa4_compute_cap = _parse_fa4_supported_caps()

    return {
        "fa2": {
            "supports_fp8": False,
            "supports_sink": False,
        },
        "fa3": {
            "compute_capability": fa3_compute_cap,
            "supports_fp8": fa3_supports_fp8,
            "supports_sink": fa3_supports_sinks,
        },
        "fa4": {
            "compute_capability": fa4_compute_cap,
            "supports_fp8": False,
            "supports_sink": False,
        },
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
    """Parse flashinfer.py to detect TRTLLM-specific features.

    FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
    capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.

    Returns:
        A dict with 'native' and 'trtllm' feature overrides, or {} when the
        file is missing/unparseable or no CC check is found.
    """
    if not FLASHINFER_UTILS_FILE.exists():
        return {}

    try:
        tree = ast.parse(FLASHINFER_UTILS_FILE.read_text())
    except Exception:
        return {}

    # CC gate in supports_trtllm_attention() identifies the TRTLLM pathway.
    trtllm_compute_cap = _find_cc_in_function(tree, "supports_trtllm_attention")

    if not trtllm_compute_cap:
        return {}

    return {
        "native": {
            # Native FlashInfer: everything except SM100
            "supports_sink": False,
        },
        "trtllm": {
            # TRTLLM pathway on Blackwell
            "compute_capability": trtllm_compute_cap,
            "supports_sink": True,
        },
    }
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def parse_mla_prefill_backends() -> list[dict[str, Any]]:
    """Parse MLA prefill backend options from mla_attention.py.

    MLA uses different backends for prefill vs decode. The decode backends are
    registered in the registry, but prefill backends are selected at runtime
    based on conditions in MLACommonImpl.__init__.

    Returns a list of prefill backend info dicts with their requirements.
    Each dict carries display fields for the docs table: name, description,
    compute_capability, enable/disable flags, and notes.
    """
    # Gracefully degrade to an empty list when the source file is absent
    # or cannot be parsed (e.g. during partial checkouts).
    if not MLA_ATTENTION_FILE.exists():
        return []

    try:
        tree = ast.parse(MLA_ATTENTION_FILE.read_text())
    except Exception:
        return []

    # Find compute capability requirements by parsing use_* functions.
    # A falsy result means the corresponding pathway was not found.
    trtllm_cc = _find_cc_in_function(tree, "use_trtllm_ragged_deepseek_prefill")
    flashinfer_cc = _find_cc_in_function(tree, "use_flashinfer_prefill")
    cudnn_cc = _find_cc_in_function(tree, "use_cudnn_prefill")

    # Build prefill backend list based on what we found
    # Order matches the priority in MLACommonImpl.__init__
    prefill_backends: list[dict[str, Any]] = []

    # TRT-LLM Ragged (highest priority if available)
    if trtllm_cc:
        prefill_backends.append(
            {
                "name": "TRT-LLM Ragged‡",
                "description": "TensorRT-LLM ragged attention",
                "compute_capability": trtllm_cc,
                "enable": "Default on SM100",
                "disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
                "notes": "DeepSeek R1 dims only",
            }
        )

    # FlashInfer prefill
    if flashinfer_cc:
        prefill_backends.append(
            {
                "name": "FlashInfer",
                "description": "FlashInfer CUTLASS backend",
                "compute_capability": flashinfer_cc,
                "enable": "`-ac.disable_flashinfer_prefill=0`",
                "disable": "`-ac.disable_flashinfer_prefill=1`",
                "notes": "DeepSeek R1 dims only",
            }
        )

    # cuDNN prefill
    if cudnn_cc:
        prefill_backends.append(
            {
                "name": "cuDNN",
                "description": "cuDNN-based attention",
                "compute_capability": cudnn_cc,
                "enable": "`-ac.use_cudnn_prefill=1`",
                "disable": "`-ac.use_cudnn_prefill=0`",
                "notes": "",
            }
        )

    # FlashAttention is always available as fallback
    prefill_backends.append(
        {
            "name": "FlashAttention",
            "description": "FlashAttention varlen (FA2/FA3)",
            "compute_capability": "Any",
            "enable": "Default fallback",
            "disable": "Use other backends",
            "notes": "FA3 on SM90, FA2 otherwise",
        }
    )

    return prefill_backends
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
2026-03-01 18:44:57 -05:00
|
|
|
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def _expand_flash_attn_variants(
|
|
|
|
|
all_backends: list[dict[str, Any]],
|
|
|
|
|
fa_features: dict[str, dict[str, Any]],
|
|
|
|
|
) -> list[dict[str, Any]]:
|
2026-03-01 18:44:57 -05:00
|
|
|
"""Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
|
2026-02-09 18:33:43 -05:00
|
|
|
expanded = []
|
|
|
|
|
for backend in all_backends:
|
|
|
|
|
if backend["name"] != "FLASH_ATTN":
|
|
|
|
|
backend.setdefault("_sort_key", backend["name"])
|
|
|
|
|
backend.setdefault("_sort_order", 0)
|
|
|
|
|
backend.setdefault("version", "")
|
|
|
|
|
expanded.append(backend)
|
|
|
|
|
continue
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# Create FA2 entry (keeps base backend's compute_capability)
|
|
|
|
|
fa2 = backend.copy()
|
|
|
|
|
fa2["version"] = "FA2*"
|
|
|
|
|
fa2["_sort_key"] = "FLASH_ATTN"
|
|
|
|
|
fa2["_sort_order"] = 0
|
|
|
|
|
fa2["supports_sink"] = fa_features["fa2"]["supports_sink"]
|
|
|
|
|
|
|
|
|
|
# Create FA3 entry (uses parsed compute_capability from fa_utils)
|
|
|
|
|
fa3 = backend.copy()
|
|
|
|
|
fa3["version"] = "FA3*"
|
|
|
|
|
fa3["_sort_key"] = "FLASH_ATTN"
|
|
|
|
|
fa3["_sort_order"] = 1
|
|
|
|
|
if fa_features["fa3"]["compute_capability"]:
|
|
|
|
|
fa3["compute_capability"] = fa_features["fa3"]["compute_capability"]
|
|
|
|
|
fa3["supports_sink"] = fa_features["fa3"]["supports_sink"]
|
|
|
|
|
if fa_features["fa3"]["supports_fp8"]:
|
|
|
|
|
base_dtypes = backend["kv_cache_dtypes"].split(", ")
|
|
|
|
|
fp8_dtypes = ["fp8", "fp8_e4m3", "fp8_e5m2"]
|
|
|
|
|
new_dtypes = [d for d in fp8_dtypes if d not in base_dtypes]
|
|
|
|
|
fa3["kv_cache_dtypes"] = ", ".join(base_dtypes + new_dtypes)
|
|
|
|
|
|
|
|
|
|
expanded.append(fa2)
|
|
|
|
|
expanded.append(fa3)
|
2026-03-01 18:44:57 -05:00
|
|
|
|
|
|
|
|
# Create FA4 entry if FA4 features are available
|
|
|
|
|
if "fa4" in fa_features:
|
|
|
|
|
fa4 = backend.copy()
|
|
|
|
|
fa4["version"] = "FA4*"
|
|
|
|
|
fa4["_sort_key"] = "FLASH_ATTN"
|
|
|
|
|
fa4["_sort_order"] = 2
|
|
|
|
|
if fa_features["fa4"].get("compute_capability"):
|
|
|
|
|
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
|
|
|
|
|
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
|
|
|
|
|
expanded.append(fa4)
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
return expanded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _expand_flashinfer_variants(
|
|
|
|
|
all_backends: list[dict[str, Any]],
|
|
|
|
|
fi_features: dict[str, dict[str, Any]],
|
|
|
|
|
) -> list[dict[str, Any]]:
|
|
|
|
|
"""Expand FLASHINFER into native and TRTLLM variants."""
|
|
|
|
|
expanded = []
|
|
|
|
|
for backend in all_backends:
|
|
|
|
|
if backend["name"] != "FLASHINFER":
|
|
|
|
|
expanded.append(backend)
|
|
|
|
|
continue
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# Parse original compute capability to get min CC
|
|
|
|
|
orig_cap = backend["compute_capability"]
|
|
|
|
|
parts = orig_cap.replace(".x", "").split("-")
|
|
|
|
|
min_cc = parts[0] if parts else "7"
|
|
|
|
|
trtllm_cc = fi_features["trtllm"]["compute_capability"]
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# Create native entry (pre-Blackwell GPUs)
|
|
|
|
|
native = backend.copy()
|
|
|
|
|
native["version"] = "Native†"
|
|
|
|
|
native["_sort_key"] = "FLASHINFER"
|
|
|
|
|
native["_sort_order"] = 0
|
|
|
|
|
native["supports_sink"] = fi_features["native"]["supports_sink"]
|
|
|
|
|
native["compute_capability"] = f"{min_cc}.x-9.x"
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# Create TRTLLM entry
|
|
|
|
|
trtllm = backend.copy()
|
|
|
|
|
trtllm["version"] = "TRTLLM†"
|
|
|
|
|
trtllm["_sort_key"] = "FLASHINFER"
|
|
|
|
|
trtllm["_sort_order"] = 1
|
|
|
|
|
trtllm["compute_capability"] = trtllm_cc
|
|
|
|
|
trtllm["supports_sink"] = fi_features["trtllm"]["supports_sink"]
|
2026-01-28 17:20:22 -05:00
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
expanded.append(native)
|
|
|
|
|
expanded.append(trtllm)
|
|
|
|
|
return expanded
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# CUDA priority list parsing
|
|
|
|
|
# ---------------------------------------------------------------------------
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_cuda_priority_lists() -> dict[str, list[str]]:
    """Parse backend priority lists from cuda.py using AST.

    The structure of _get_backend_priorities is::

        if use_mla:
            if device_capability.major == 10:
                return [MLA list for SM100]
            else:
                return [MLA list for default]
        else:
            if device_capability.major == 10:
                return [Standard list for SM100]
            else:
                return [Standard list for default]

    Returns:
        A dict keyed by "{mla|standard}_{sm100|default}" mapping to backend
        name lists, or {} when cuda.py is missing or unparseable.
    """
    if not CUDA_PLATFORM_FILE.exists():
        return {}
    try:
        tree = ast.parse(CUDA_PLATFORM_FILE.read_text())
    except Exception:
        return {}

    priorities: dict[str, list[str]] = {}

    # Locate _get_backend_priorities and walk its top-level if statements.
    target_funcs = [
        n
        for n in ast.walk(tree)
        if isinstance(n, ast.FunctionDef) and n.name == "_get_backend_priorities"
    ]
    for func in target_funcs:
        for stmt in func.body:
            if not isinstance(stmt, ast.If):
                continue
            # "if use_mla:" separates the MLA and standard priority trees.
            if isinstance(stmt.test, ast.Name) and stmt.test.id == "use_mla":
                _extract_priorities(stmt.body, priorities, "mla")
                if stmt.orelse:
                    _extract_priorities(stmt.orelse, priorities, "standard")
            else:
                _extract_priorities([stmt], priorities, "standard")

    return priorities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_backends_from_return(stmts: list) -> list[str]:
|
2026-02-12 12:21:54 -05:00
|
|
|
"""Extract backend names from return statements in a list of statements.
|
|
|
|
|
|
|
|
|
|
Handles starred unpacking (e.g. ``*sparse_backends``) by resolving the
|
|
|
|
|
variable from assignments found in the same statement list. When the
|
|
|
|
|
variable is conditionally assigned (inside an ``if/else``), the ``else``
|
|
|
|
|
branch value is used as the representative default.
|
|
|
|
|
"""
|
|
|
|
|
# Collect variable assignments so we can resolve starred expressions.
|
|
|
|
|
# For conditional assignments, last-written (else branch) wins.
|
|
|
|
|
var_assigns: dict[str, list[str]] = {}
|
|
|
|
|
for stmt in stmts:
|
|
|
|
|
if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.List):
|
|
|
|
|
for target in stmt.targets:
|
|
|
|
|
if isinstance(target, ast.Name):
|
|
|
|
|
var_assigns[target.id] = [
|
|
|
|
|
e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute)
|
|
|
|
|
]
|
|
|
|
|
elif isinstance(stmt, ast.If):
|
|
|
|
|
for branch in (stmt.body, stmt.orelse):
|
|
|
|
|
for branch_stmt in branch:
|
|
|
|
|
if isinstance(branch_stmt, ast.Assign) and isinstance(
|
|
|
|
|
branch_stmt.value, ast.List
|
|
|
|
|
):
|
|
|
|
|
for target in branch_stmt.targets:
|
|
|
|
|
if isinstance(target, ast.Name):
|
|
|
|
|
var_assigns[target.id] = [
|
|
|
|
|
e.attr
|
|
|
|
|
for e in branch_stmt.value.elts
|
|
|
|
|
if isinstance(e, ast.Attribute)
|
|
|
|
|
]
|
|
|
|
|
|
2026-01-28 17:20:22 -05:00
|
|
|
for stmt in stmts:
|
|
|
|
|
if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List):
|
2026-02-12 12:21:54 -05:00
|
|
|
backends: list[str] = []
|
|
|
|
|
for e in stmt.value.elts:
|
|
|
|
|
if isinstance(e, ast.Attribute):
|
|
|
|
|
backends.append(e.attr)
|
|
|
|
|
elif (
|
|
|
|
|
isinstance(e, ast.Starred)
|
|
|
|
|
and isinstance(e.value, ast.Name)
|
|
|
|
|
and e.value.id in var_assigns
|
|
|
|
|
):
|
|
|
|
|
backends.extend(var_assigns[e.value.id])
|
|
|
|
|
return backends
|
2026-01-28 17:20:22 -05:00
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_sm100_check(test: ast.expr) -> bool:
|
|
|
|
|
"""Check if test is `something.major == 10`."""
|
|
|
|
|
return (
|
|
|
|
|
isinstance(test, ast.Compare)
|
|
|
|
|
and isinstance(test.left, ast.Attribute)
|
|
|
|
|
and test.left.attr == "major"
|
|
|
|
|
and len(test.ops) == 1
|
|
|
|
|
and isinstance(test.ops[0], ast.Eq)
|
|
|
|
|
and len(test.comparators) == 1
|
|
|
|
|
and isinstance(test.comparators[0], ast.Constant)
|
|
|
|
|
and test.comparators[0].value == 10
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: str):
    """Populate *priorities* from the statements of one if/else branch.

    Keys are "{prefix}_sm100" / "{prefix}_default", chosen by whether the
    nested `if` tests for compute capability major == 10.
    """
    sm100_key = f"{prefix}_sm100"
    default_key = f"{prefix}_default"
    for stmt in body:
        if isinstance(stmt, ast.If):
            # Map the if/else arms to sm100 vs default based on the test.
            if _is_sm100_check(stmt.test):
                branch_keys = (sm100_key, default_key)
            else:
                branch_keys = (default_key, sm100_key)
            for key, branch in zip(branch_keys, (stmt.body, stmt.orelse)):
                if backends := _get_backends_from_return(branch):
                    priorities[key] = backends
        elif isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List):
            # A bare `return [...]` is the unconditional default list.
            priorities[default_key] = [
                e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute)
            ]
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Data-driven table rendering
|
|
|
|
|
#
|
|
|
|
|
# Each column is a (header, formatter) pair. The formatter takes a backend
|
|
|
|
|
# info dict and returns the cell string. Tables are assembled by selecting
|
|
|
|
|
# which columns to include, then calling _render_table().
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
# Column type alias for readability: (header text, cell formatter).
# The formatter receives a backend info dict and returns the cell string.
TableColumn = tuple[str, Callable[[dict[str, Any]], str]]

# Shared column definitions -- order here matches the output table order
_COL_BACKEND: TableColumn = ("Backend", lambda b: f"`{b['name']}`")
_COL_VERSION: TableColumn = ("Version", lambda b: b.get("version", ""))
_COL_DTYPES: TableColumn = ("Dtypes", lambda b: b["dtypes"])
_COL_KV_DTYPES: TableColumn = (
    "KV Dtypes",
    lambda b: add_literal_quotes(b["kv_cache_dtypes"]),
)
_COL_BLOCK_SIZES: TableColumn = ("Block Sizes", lambda b: b["block_sizes"])
_COL_HEAD_SIZES: TableColumn = ("Head Sizes", lambda b: b["head_sizes"])
_COL_SINK: TableColumn = ("Sink", lambda b: bool_to_emoji(b["supports_sink"]))
_COL_SPARSE: TableColumn = ("Sparse", lambda b: bool_to_emoji(b["is_sparse"]))
_COL_MM_PREFIX: TableColumn = (
    "MM Prefix",
    lambda b: bool_to_emoji(b["supports_mm_prefix"]),
)
_COL_DCP: TableColumn = ("DCP", lambda b: bool_to_emoji(b["supports_dcp"]))
_COL_ATTN_TYPES: TableColumn = ("Attention Types", lambda b: b["attn_types"])
_COL_COMPUTE_CAP: TableColumn = ("Compute Cap.", lambda b: b["compute_capability"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_literal_quotes(value: str) -> str:
    """Wrap each comma-separated item of *value* in markdown backticks."""
    quoted = (f"`{part.strip()}`" for part in value.split(","))
    return ", ".join(quoted)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bool_to_emoji(value: bool) -> str:
    """Render a boolean as ✅ (supported) or ❌ (not supported)."""
    if value:
        return "✅"
    return "❌"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_columns(is_mla: bool, has_versions: bool) -> list[TableColumn]:
    """Assemble the ordered column list for a backend feature table.

    The Version column appears only when some backend has version variants;
    the Sparse column appears only in MLA tables.
    """
    version_cols = [_COL_VERSION] if has_versions else []
    sparse_cols = [_COL_SPARSE] if is_mla else []
    return (
        [_COL_BACKEND]
        + version_cols
        + [_COL_DTYPES, _COL_KV_DTYPES, _COL_BLOCK_SIZES, _COL_HEAD_SIZES]
        + [_COL_SINK]
        + sparse_cols
        + [_COL_MM_PREFIX, _COL_DCP, _COL_ATTN_TYPES, _COL_COMPUTE_CAP]
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sort_key(x: dict[str, Any]) -> tuple[str, int]:
|
|
|
|
|
"""Sort key that keeps parent/child rows together in order."""
|
|
|
|
|
return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_table(
    columns: list[TableColumn],
    backends: list[dict[str, Any]],
) -> list[str]:
    """Render a markdown table from column specs and backend data."""
    headers = [name for name, _ in columns]
    lines = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join("-" * (len(name) + 2) for name in headers) + "|",
    ]
    # Sort so variant rows stay grouped (same key logic as _sort_key).
    ordered = sorted(
        backends,
        key=lambda b: (b.get("_sort_key", b["name"]), b.get("_sort_order", 0)),
    )
    for info in ordered:
        cells = [fmt(info) for _, fmt in columns]
        lines.append("| " + " | ".join(cells) + " |")
    return lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_markdown_table(
    backends: list[dict[str, Any]], title: str, is_mla_table: bool = False
) -> str:
    """Generate a titled markdown feature table for *backends*."""
    if not backends:
        # Emit an explicit placeholder so the document structure is stable.
        return f"## {title}\n\nNo backends found.\n"
    show_versions = any(b.get("version") for b in backends)
    table_lines = _render_table(_build_columns(is_mla_table, show_versions), backends)
    parts = [f"## {title}", "", *table_lines, ""]
    return "\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Markdown section generators (usage, priority, legend, MLA)
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
2026-01-28 17:20:22 -05:00
|
|
|
def generate_usage_section() -> str:
    """Generate the usage documentation section.

    Returns a static markdown block covering CLI flags, the Python API, and
    manual vs automatic backend selection behavior.
    """
    return """## Setting the Attention Backend

### Command Line

There are two ways to specify the backend from the command line:

**Option 1: Using `--attention-backend` (simple)**

```bash
vllm serve <model> --attention-backend FLASH_ATTN
```

**Option 2: Using `--attention-config.backend` / `-ac.backend` (structured config)**

```bash
# Dot notation
vllm serve <model> --attention-config.backend FLASH_ATTN
vllm serve <model> -ac.backend FLASH_ATTN

# JSON format
vllm serve <model> --attention-config '{"backend": "FLASH_ATTN"}'
vllm serve <model> -ac '{"backend": "FLASH_ATTN"}'
```

> **Note:** `--attention-backend` and `--attention-config.backend` are mutually
> exclusive. Use one or the other, not both.

### Python API

Use `AttentionConfig` with the `LLM` class:

```python
from vllm import LLM
from vllm.config import AttentionConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# Method 1: Using AttentionConfig with enum
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN),
)

# Method 2: Using attention_backend parameter with string
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    attention_backend="FLASH_ATTN",
)
```

## Backend Selection Behavior

### Manual Selection

When you explicitly set a backend via `--attention-backend` or `AttentionConfig`:

1. The backend is **validated** against your configuration (model dtype, head
   size, compute capability, etc.)
2. If the backend **doesn't support** your configuration, an error is raised
   with the specific reason
3. If valid, the backend is used

Example error when selecting an incompatible backend:

```text
ValueError: Selected backend FLASHMLA is not valid for this configuration.
Reason: ['compute capability not supported']
```

### Automatic Selection

When no backend is specified (the default):

1. vLLM iterates through backends in **priority order** (see tables below)
2. Each backend is validated against your configuration
3. The **first compatible backend** is selected
4. If no backend is compatible, an error is raised listing all backends and
   their incompatibility reasons
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _priority_table(title: str, backends: list[str]) -> list[str]:
|
|
|
|
|
"""Generate a priority table for a list of backends."""
|
|
|
|
|
return [
|
|
|
|
|
f"**{title}:**",
|
|
|
|
|
"",
|
|
|
|
|
"| Priority | Backend |",
|
|
|
|
|
"|----------|---------|",
|
|
|
|
|
*[f"| {i} | `{b}` |" for i, b in enumerate(backends, 1)],
|
|
|
|
|
"",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_priority_section(priorities: dict[str, list[str]]) -> str:
    """Generate the CUDA backend priority ranking section."""
    sm100_title = "Blackwell (SM 10.x)"
    ampere_title = "Ampere/Hopper (SM 8.x-9.x)"

    lines = [
        "## Backend Priority (CUDA)",
        "",
        "When no backend is explicitly selected, vLLM chooses the first",
        "compatible backend from these priority-ordered lists.",
        "",
        "Priority is **1 = highest** (tried first).",
        "",
        "### Standard Attention (MHA, MQA, GQA)",
        "",
    ]

    # Each subsection only appears when the parser found its priority list.
    for key, title in (
        ("standard_sm100", sm100_title),
        ("standard_default", ampere_title),
    ):
        if key in priorities:
            lines.extend(_priority_table(title, priorities[key]))

    lines.extend(["### MLA Attention (DeepSeek-style)", ""])

    for key, title in (("mla_sm100", sm100_title), ("mla_default", ampere_title)):
        if key in priorities:
            lines.extend(_priority_table(title, priorities[key]))

    lines.append(
        "> **Note:** ROCm and CPU platforms have their own selection logic. "
        "See the platform-specific documentation for details."
    )
    lines.append("")
    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
def generate_legend() -> str:
    """Generate a legend explaining the feature-table columns and symbols."""
    return """## Legend

| Column | Description |
|--------|-------------|
| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
| **Head Sizes** | Supported attention head sizes |
| **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support |
| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |

**Symbols:** ✅ = Supported, ❌ = Not supported
"""
|
|
|
|
|
|
|
|
|
|
|
2026-01-28 17:20:22 -05:00
|
|
|
def generate_mla_section(
    prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
) -> str:
    """Generate the complete MLA section with prefill and decode tables.

    Args:
        prefill_backends: Dicts from parse_mla_prefill_backends() with
            name/description/compute_capability/enable/disable/notes fields.
        decode_backends: Backend info dicts rendered via the shared
            data-driven table machinery.

    Returns:
        The MLA markdown section as a single string.
    """
    lines = [
        "## MLA (Multi-head Latent Attention) Backends",
        "",
        "MLA uses separate backends for prefill and decode phases.",
        "",
        "### Prefill Backends",
        "",
        "The prefill backend is selected at runtime based on hardware and",
        "configuration.",
        "",
        "| Backend | Description | Compute Cap. | Enable | Disable | Notes |",
        "|---------|-------------|--------------|--------|---------|-------|",
    ]

    # Prefill backends use a bespoke table layout (enable/disable columns).
    for backend in prefill_backends:
        row = "| {} | {} | {} | {} | {} | {} |".format(
            backend["name"],
            backend["description"],
            backend["compute_capability"],
            backend["enable"],
            backend["disable"],
            backend.get("notes", ""),
        )
        lines.append(row)

    lines.extend(
        [
            "",
            "> **‡** TRT-LLM Ragged is the default on Blackwell (SM100).",
            "> On other GPUs, FlashAttention is used as the default.",
            "",
            "### Decode Backends",
            "",
        ]
    )

    # Reuse data-driven table rendering for decode backends
    columns = _build_columns(is_mla=True, has_versions=False)
    lines.extend(_render_table(columns, decode_backends))

    lines.append("")
    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
2026-02-09 18:33:43 -05:00
|
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
# Top-level orchestration
|
|
|
|
|
# ---------------------------------------------------------------------------
|
2026-01-28 17:20:22 -05:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_docs() -> str:
    """Generate the complete documentation.

    Orchestrates all the AST parsers (registry, cuda.py priorities,
    FlashAttention/FlashInfer variant features, MLA prefill backends) and
    assembles the markdown document section by section.

    Returns:
        The full markdown document as a single string.
    """
    attention_backends_map = parse_registry()

    # Parse priority lists from cuda.py
    priorities = parse_cuda_priority_lists()

    # Parse FlashAttention FA2/FA3 feature differences
    fa_features = parse_flash_attn_features()

    # Parse FlashInfer TRTLLM feature differences (native vs TRTLLM on Blackwell)
    fi_features = parse_flashinfer_trtllm_features()

    # Parse MLA prefill backends
    mla_prefill_backends = parse_mla_prefill_backends()

    # Collect backend info; analyze_backend may return None for
    # backends that cannot be analyzed, which are dropped.
    all_backends = []
    for backend_name, class_path in attention_backends_map.items():
        if backend_name in SKIP_BACKENDS:
            continue
        info = analyze_backend(backend_name, class_path)
        if info:
            all_backends.append(info)

    # Expand backends into version variants (only when the feature
    # parsers produced data; otherwise rows stay unexpanded).
    if fa_features:
        all_backends = _expand_flash_attn_variants(all_backends, fa_features)
    if fi_features:
        all_backends = _expand_flashinfer_variants(all_backends, fi_features)

    # Split into MLA and non-MLA
    mla_backends = [b for b in all_backends if b["is_mla"]]
    non_mla_backends = [b for b in all_backends if not b["is_mla"]]

    # Generate documentation, starting with the auto-generated header.
    script_path = "tools/pre_commit/generate_attention_backend_docs.py"
    doc_lines = [
        "# Attention Backend Feature Support",
        "",
        f"This document is auto-generated by `{script_path}`.",
        "It shows the feature support for each registered attention backend",
        "based on the checks in `AttentionBackend.validate_configuration()`.",
        "",
        "**Do not edit this file manually.** Run the following command to",
        "regenerate it:",
        "",
        "```bash",
        f"python {script_path}",
        "```",
        "",
    ]

    # Add usage documentation
    doc_lines.append(generate_usage_section())

    # Add priority section
    doc_lines.append(generate_priority_section(priorities))

    # Add legend and feature tables
    doc_lines.append(generate_legend())
    standard_title = "Standard Attention (MHA, MQA, GQA) Backends"
    doc_lines.append(
        generate_markdown_table(non_mla_backends, standard_title, is_mla_table=False)
    )
    # Add footnotes for version/variant distinctions (in table order)
    footnotes = []
    if fi_features:
        footnotes.append(
            "> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which "
            "supports sinks. Disable via `--attention-config.use_trtllm_attention=0`."
        )
    if fa_features:
        footnotes.append(
            "> **\\*** Specify the FlashAttention version via "
            "`--attention-config.flash_attn_version=2`, `3`, or `4`. "
            "Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
            "FA2 otherwise."
        )
    if footnotes:
        doc_lines.append("\n>\n".join(footnotes) + "\n")

    # Add MLA section with prefill and decode backends
    doc_lines.append(generate_mla_section(mla_prefill_backends, mla_backends))

    return "\n".join(doc_lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: regenerate or verify the attention backend docs."""
    parser = argparse.ArgumentParser(
        description="Generate attention backend documentation table"
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=str(REPO_ROOT / "docs" / "design" / "attention_backends.md"),
        help="Output file path (default: docs/design/attention_backends.md)",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check if the documentation is up to date (for pre-commit)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        help="Files to check (passed by pre-commit). If none are relevant, skip.",
    )
    args = parser.parse_args()

    # Pre-commit passes the changed files; skip entirely if none are relevant.
    if args.files and not any(is_relevant_file(f) for f in args.files):
        sys.exit(0)

    target = Path(args.output)
    content = generate_docs()

    def _write() -> None:
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content)

    if not args.check:
        _write()
        print(f"Generated: {target}")
        return

    # --check mode: regenerate the file in place (so the hook leaves a fixed
    # copy behind) and exit non-zero when the on-disk doc was stale.
    up_to_date = target.exists() and target.read_text() == content
    if up_to_date:
        print(f"✅ Up to date: {target}")
        sys.exit(0)
    _write()
    print(f"🔄 Regenerated: {target}")
    sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Allow running the generator directly (outside the pre-commit hook).
if __name__ == "__main__":
    main()
|