diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index 6e84dde92..b551e31db 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -152,6 +152,7 @@ Priority is **1 = highest** (tried first).
 | **Sink** | Attention sink support (for StreamingLLM) |
 | **Sparse** | Sparse attention support (MLA only) |
 | **MM Prefix** | Multimodal prefix full attention support |
+| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
 | **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
 | **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
 
@@ -159,20 +160,20 @@ Priority is **1 = highest** (tried first).
 ## Standard Attention (MHA, MQA, GQA) Backends
 
-| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | Attention Types | Compute Cap. |
-|---------|---------|--------|-----------|-------------|------------|------|-----------|-----------------|--------------|
-| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | All | N/A |
-| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | Decoder | 7.x-9.x |
-| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | Decoder | 10.x |
-| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | All | ≥8.0 |
-| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | All | 9.x |
-| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | Decoder | Any |
-| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | Decoder, Encoder Only | Any |
-| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | Decoder | N/A |
-| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | Decoder | N/A |
-| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | Decoder | Any |
-| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | All | Any |
+| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
+|---------|---------|--------|-----------|-------------|------------|------|-----------|-----|-----------------|--------------|
+| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
+| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
+| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
+| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
+| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
+| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
+| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
+| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
+| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
+| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
 
 > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
 >
@@ -199,14 +200,14 @@ configuration.
 
 ### Decode Backends
 
-| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |
-|---------|--------|-----------|-------------|------------|------|--------|-----------|-----------------|--------------|
-| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | Decoder | 10.x |
-| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | Decoder | 10.x |
-| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | Decoder | 9.x-10.x |
-| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | Decoder | 9.x-10.x |
-| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | Decoder | 9.x |
-| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
-| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | Decoder | Any |
+| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
+|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
+| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
+| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
+| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
+| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
+| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
+| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
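A note for table readers: the new DCP column is generated, not hand-edited. As the tooling changes below show, it is derived from the `can_return_lse_for_decode` flag on each backend's impl class. A minimal sketch of the shape the generator looks for (the class names here are invented for illustration, not real vLLM backends):

```python
# Hypothetical impl classes, for illustration only. The generator walks the
# AST looking for a boolean class attribute named `can_return_lse_for_decode`,
# following the inheritance chain when the subclass does not set it itself.


class ToyAttentionImpl:
    # Base impl: cannot return LSE during decode, so no DCP support.
    can_return_lse_for_decode: bool = False


class ToyDCPAttentionImpl(ToyAttentionImpl):
    # Overrides the flag; the generated table would show ✅ under DCP.
    can_return_lse_for_decode: bool = True
```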
diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py
index 3cca4959d..eb68deb1b 100644
--- a/tools/pre_commit/generate_attention_backend_docs.py
+++ b/tools/pre_commit/generate_attention_backend_docs.py
@@ -17,9 +17,14 @@ import argparse
 import ast
 import fnmatch
 import sys
+from collections.abc import Callable
 from pathlib import Path
 from typing import Any
 
+# ---------------------------------------------------------------------------
+# Constants and file paths
+# ---------------------------------------------------------------------------
+
 REPO_ROOT = Path(__file__).parent.parent.parent
 
 RELEVANT_PATTERNS = [
@@ -32,6 +37,18 @@ RELEVANT_PATTERNS = [
     "docs/design/attention_backends.md",
 ]
 
+BACKENDS_DIR = REPO_ROOT / "vllm" / "v1" / "attention" / "backends"
+REGISTRY_FILE = BACKENDS_DIR / "registry.py"
+CUDA_PLATFORM_FILE = REPO_ROOT / "vllm" / "platforms" / "cuda.py"
+FA_UTILS_FILE = BACKENDS_DIR / "fa_utils.py"
+FLASHINFER_UTILS_FILE = REPO_ROOT / "vllm" / "utils" / "flashinfer.py"
+MLA_ATTENTION_FILE = (
+    REPO_ROOT / "vllm" / "model_executor" / "layers" / "attention" / "mla_attention.py"
+)
+
+# Backends to skip during doc generation
+SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
+
 
 def is_relevant_file(filepath: str) -> bool:
     """Check if a file matches any of the relevant patterns."""
@@ -46,14 +63,197 @@ def is_relevant_file(filepath: str) -> bool:
     return any(fnmatch.fnmatch(path_str, pattern) for pattern in RELEVANT_PATTERNS)
 
 
-BACKENDS_DIR = REPO_ROOT / "vllm" / "v1" / "attention" / "backends"
-REGISTRY_FILE = BACKENDS_DIR / "registry.py"
-CUDA_PLATFORM_FILE = REPO_ROOT / "vllm" / "platforms" / "cuda.py"
-FA_UTILS_FILE = BACKENDS_DIR / "fa_utils.py"
-FLASHINFER_UTILS_FILE = REPO_ROOT / "vllm" / "utils" / "flashinfer.py"
-MLA_ATTENTION_FILE = (
-    REPO_ROOT / "vllm" / "model_executor" / "layers" / "attention" / "mla_attention.py"
-)
+# ---------------------------------------------------------------------------
+# AST utility helpers
+# ---------------------------------------------------------------------------
+
+
+def find_class_in_ast(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
+    """Find a class definition in an AST."""
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ClassDef) and node.name == class_name:
+            return node
+    return None
+
+
+def find_method(node: ast.ClassDef, method_name: str) -> ast.FunctionDef | None:
+    """Find a method in a class definition."""
+    for item in node.body:
+        if isinstance(item, ast.FunctionDef) and item.name == method_name:
+            return item
+    return None
+
+
+def method_returns_true(method: ast.FunctionDef | None) -> bool:
+    """Check if a method simply returns True."""
+    if method is None:
+        return False
+    for node in ast.walk(method):
+        if (
+            isinstance(node, ast.Return)
+            and isinstance(node.value, ast.Constant)
+            and node.value.value is True
+        ):
+            return True
+    return False
+
+
+def check_method_overrides(node: ast.ClassDef, method_name: str) -> bool:
+    """Check if a method is overridden and returns True."""
+    return method_returns_true(find_method(node, method_name))
+
+
+def _find_bool_class_var(class_node: ast.ClassDef, var_name: str) -> bool | None:
+    """Find a bool class variable in a class definition. Returns None if not found."""
+    for item in class_node.body:
+        # Check for annotated assignment: attr: bool = True/False
+        if (
+            isinstance(item, ast.AnnAssign)
+            and isinstance(item.target, ast.Name)
+            and item.target.id == var_name
+            and isinstance(item.value, ast.Constant)
+            and isinstance(item.value.value, bool)
+        ):
+            return item.value.value
+        # Check for plain assignment: attr = True/False
+        if isinstance(item, ast.Assign):
+            for target in item.targets:
+                if (
+                    isinstance(target, ast.Name)
+                    and target.id == var_name
+                    and isinstance(item.value, ast.Constant)
+                    and isinstance(item.value.value, bool)
+                ):
+                    return item.value.value
+    return None
+
+
+def _parse_list_class_var(node: ast.ClassDef, var_name: str) -> list[str] | None:
+    """Parse a list-type class variable, returning None if not found."""
+    for item in node.body:
+        if not isinstance(item, ast.AnnAssign):
+            continue
+        if not isinstance(item.target, ast.Name):
+            continue
+        if item.target.id != var_name:
+            continue
+        if not (item.value and isinstance(item.value, ast.List)):
+            continue
+        result = []
+        for elt in item.value.elts:
+            if isinstance(elt, ast.Attribute):
+                result.append(elt.attr)
+            elif isinstance(elt, ast.Constant):
+                result.append(str(elt.value))
+        return result
+    return None
+
+
+def _parse_return_list(
+    method: ast.FunctionDef | None, handle_multiple_of: bool = False
+) -> list[str]:
+    """Extract list items from a method's return statement."""
+    if method is None:
+        return []
+    for stmt in ast.walk(method):
+        if not isinstance(stmt, ast.Return):
+            continue
+        if not isinstance(stmt.value, ast.List):
+            continue
+        sizes = []
+        for elt in stmt.value.elts:
+            if isinstance(elt, ast.Constant):
+                sizes.append(str(elt.value))
+            elif (
+                handle_multiple_of
+                and isinstance(elt, ast.Call)
+                and isinstance(elt.func, ast.Name)
+                and elt.func.id == "MultipleOf"
+                and elt.args
+                and isinstance(elt.args[0], ast.Constant)
+            ):
+                sizes.append(f"%{elt.args[0].value}")
+        if sizes:
+            return sizes
+    return []
+
+
+def _get_parent_class_name(class_node: ast.ClassDef) -> str | None:
+    """Get the first parent class name (simple name only).
+
+    Handles both simple inheritance (class Foo(Bar)) and generic
+    inheritance (class Foo(Bar[T])).
+    """
+    if not class_node.bases:
+        return None
+    base = class_node.bases[0]
+    if isinstance(base, ast.Name):
+        return base.id
+    if isinstance(base, ast.Subscript) and isinstance(base.value, ast.Name):
+        return base.value.id
+    return None
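The helpers above are easiest to understand on a tiny input. A REPL-style sketch, assuming the helpers defined in this file are in scope (the inline classes are fabricated for the example):

```python
# Illustrative check of find_class_in_ast and _find_bool_class_var on a
# made-up source string.
import ast

source = """
class Base:
    can_return_lse_for_decode: bool = False

class Child(Base):
    pass
"""

tree = ast.parse(source)

base = find_class_in_ast(tree, "Base")  # -> ast.ClassDef for Base
assert _find_bool_class_var(base, "can_return_lse_for_decode") is False

# Child does not set the flag itself, so the single-class lookup misses it;
# that is exactly the gap parse_impl_bool_attr (below) fills by walking the
# inheritance chain.
child = find_class_in_ast(tree, "Child")
assert _find_bool_class_var(child, "can_return_lse_for_decode") is None
```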
+
+
+def _resolve_import_to_file(
+    tree: ast.AST, class_name: str, source_file: Path | None = None
+) -> Path | None:
+    """Try to resolve a class name to its source file via imports in the AST.
+
+    Handles both absolute imports (from vllm.foo import Bar) and relative
+    imports (from .foo import Bar) when source_file is provided.
+    """
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.ImportFrom):
+            continue
+        for alias in node.names:
+            actual_name = alias.asname or alias.name
+            if actual_name != class_name:
+                continue
+            if not node.module:
+                continue
+
+            if node.level and node.level > 0 and source_file:
+                # Relative import: resolve from the source file's directory
+                base_dir = source_file.parent
+                for _ in range(node.level - 1):
+                    base_dir = base_dir.parent
+                module_path = node.module.replace(".", "/")
+                py_file = base_dir / f"{module_path}.py"
+            else:
+                # Absolute import
+                module_path = node.module.replace(".", "/")
+                py_file = REPO_ROOT / f"{module_path}.py"
+
+            if py_file.exists():
+                return py_file
+    return None
+
+
+def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None:
+    """Find a compute capability from is_device_capability_family() calls in a function.
+
+    Looks for the pattern: current_platform.is_device_capability_family(N)
+    and converts N (e.g. 100) to a CC string (e.g. "10.x").
+    """
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.FunctionDef) or node.name != func_name:
+            continue
+        for n in ast.walk(node):
+            if (
+                isinstance(n, ast.Call)
+                and isinstance(n.func, ast.Attribute)
+                and n.func.attr == "is_device_capability_family"
+                and n.args
+                and isinstance(n.args[0], ast.Constant)
+                and isinstance(n.args[0].value, int)
+            ):
+                return f"{n.args[0].value // 10}.x"
+    return None
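A quick worked example of the conversion `_find_cc_in_function` performs, on a fabricated function body (the `use_cutlass_mla` name is invented for the example):

```python
# 100 // 10 == 10, so the helper reports the device family as "10.x".
import ast

source = """
def use_cutlass_mla():
    return current_platform.is_device_capability_family(100)
"""

assert _find_cc_in_function(ast.parse(source), "use_cutlass_mla") == "10.x"
```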
+
+
+# ---------------------------------------------------------------------------
+# Registry and file resolution
+# ---------------------------------------------------------------------------
 
 
 def parse_registry() -> dict[str, str]:
@@ -88,309 +288,9 @@ def get_file_from_class_path(class_path: str) -> Path | None:
     return py_file if py_file.exists() else None
 
 
-def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
-    """Parse fa_utils.py to detect FA2 vs FA3 feature differences.
-
-    Returns a dict with 'fa2' and 'fa3' keys containing their respective
-    feature overrides for compute capability, KV cache dtypes, and sink support.
-    """
-    if not FA_UTILS_FILE.exists():
-        return {}
-
-    try:
-        tree = ast.parse(FA_UTILS_FILE.read_text())
-    except Exception:
-        return {}
-
-    # Analyze the functions to determine FA3-specific features
-    fa3_supports_fp8 = False
-    fa3_supports_sinks = False
-    fa3_compute_cap: str | None = None
-
-    for node in ast.walk(tree):
-        if not isinstance(node, ast.FunctionDef):
-            continue
-
-        # Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
-        if node.name == "flash_attn_supports_fp8":
-            for n in ast.walk(node):
-                if (
-                    isinstance(n, ast.Compare)
-                    and isinstance(n.left, ast.Call)
-                    and isinstance(n.left.func, ast.Name)
-                    and n.left.func.id == "get_flash_attn_version"
-                ):
-                    fa3_supports_fp8 = True
-                    break
-
-        # Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
-        if node.name == "flash_attn_supports_sinks":
-            for n in ast.walk(node):
-                if (
-                    isinstance(n, ast.Compare)
-                    and isinstance(n.left, ast.Call)
-                    and isinstance(n.left.func, ast.Name)
-                    and n.left.func.id == "get_flash_attn_version"
-                ):
-                    fa3_supports_sinks = True
-                    break
-
-        # Check get_flash_attn_version for FA3 compute capability
-        # Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
-        if node.name == "get_flash_attn_version":
-            for n in ast.walk(node):
-                # Look for IfExp (ternary) with `device_capability.major == 9`
-                if isinstance(n, ast.IfExp):
-                    test = n.test
-                    # Check if test is a BoolOp (and) containing the major check
-                    if isinstance(test, ast.BoolOp):
-                        for val in test.values:
-                            if (
-                                isinstance(val, ast.Compare)
-                                and isinstance(val.left, ast.Attribute)
-                                and val.left.attr == "major"
-                                and val.comparators
-                                and isinstance(val.comparators[0], ast.Constant)
-                            ):
-                                fa3_compute_cap = f"{val.comparators[0].value}.x"
-                                break
-
-    return {
-        "fa2": {
-            "supports_fp8": False,
-            "supports_sink": False,
-        },
-        "fa3": {
-            "compute_capability": fa3_compute_cap,
-            "supports_fp8": fa3_supports_fp8,
-            "supports_sink": fa3_supports_sinks,
-        },
-    }
-
-
-def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
-    """Parse flashinfer.py to detect TRTLLM-specific features.
-
-    FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
-    capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
-    """
-    if not FLASHINFER_UTILS_FILE.exists():
-        return {}
-
-    try:
-        tree = ast.parse(FLASHINFER_UTILS_FILE.read_text())
-    except Exception:
-        return {}
-
-    trtllm_compute_cap: str | None = None
-
-    for node in ast.walk(tree):
-        if not isinstance(node, ast.FunctionDef):
-            continue
-
-        # Parse supports_trtllm_attention for compute capability
-        # Look for: current_platform.is_device_capability_family(100)
-        if node.name == "supports_trtllm_attention":
-            for n in ast.walk(node):
-                if (
-                    isinstance(n, ast.Call)
-                    and isinstance(n.func, ast.Attribute)
-                    and n.func.attr == "is_device_capability_family"
-                    and n.args
-                    and isinstance(n.args[0], ast.Constant)
-                    and isinstance(n.args[0].value, int)
-                ):
-                    cap = n.args[0].value
-                    # Convert 100 -> "10.x"
-                    trtllm_compute_cap = f"{cap // 10}.x"
-                    break
-
-    if not trtllm_compute_cap:
-        return {}
-
-    return {
-        "native": {
-            # Native FlashInfer: everything except SM100
-            "supports_sink": False,
-        },
-        "trtllm": {
-            # TRTLLM pathway on Blackwell
-            "compute_capability": trtllm_compute_cap,
-            "supports_sink": True,
-        },
-    }
-
-
-def parse_mla_prefill_backends() -> list[dict[str, Any]]:
-    """Parse MLA prefill backend options from mla_attention.py.
-
-    MLA uses different backends for prefill vs decode. The decode backends are
-    registered in the registry, but prefill backends are selected at runtime
-    based on conditions in MLACommonImpl.__init__.
-
-    Returns a list of prefill backend info dicts with their requirements.
-    """
-    if not MLA_ATTENTION_FILE.exists():
-        return []
-
-    try:
-        tree = ast.parse(MLA_ATTENTION_FILE.read_text())
-    except Exception:
-        return []
-
-    # Find compute capability requirements by parsing use_* functions
-    flashinfer_cc: str | None = None
-    cudnn_cc: str | None = None
-    trtllm_cc: str | None = None
-
-    for node in ast.walk(tree):
-        if not isinstance(node, ast.FunctionDef):
-            continue
-
-        # Parse use_flashinfer_prefill for compute capability (SM100)
-        if node.name == "use_flashinfer_prefill":
-            for n in ast.walk(node):
-                if (
-                    isinstance(n, ast.Call)
-                    and isinstance(n.func, ast.Attribute)
-                    and n.func.attr == "is_device_capability_family"
-                    and n.args
-                    and isinstance(n.args[0], ast.Constant)
-                    and isinstance(n.args[0].value, int)
-                ):
-                    flashinfer_cc = f"{n.args[0].value // 10}.x"
-
-        # Parse use_cudnn_prefill for compute capability (SM100)
-        if node.name == "use_cudnn_prefill":
-            for n in ast.walk(node):
-                if (
-                    isinstance(n, ast.Call)
-                    and isinstance(n.func, ast.Attribute)
-                    and n.func.attr == "is_device_capability_family"
-                    and n.args
-                    and isinstance(n.args[0], ast.Constant)
-                    and isinstance(n.args[0].value, int)
-                ):
-                    cudnn_cc = f"{n.args[0].value // 10}.x"
-
-        # Parse use_trtllm_ragged_deepseek_prefill for compute capability
-        if node.name == "use_trtllm_ragged_deepseek_prefill":
-            for n in ast.walk(node):
-                if (
-                    isinstance(n, ast.Call)
-                    and isinstance(n.func, ast.Attribute)
-                    and n.func.attr == "is_device_capability_family"
-                    and n.args
-                    and isinstance(n.args[0], ast.Constant)
-                    and isinstance(n.args[0].value, int)
-                ):
-                    trtllm_cc = f"{n.args[0].value // 10}.x"
-
-    # Build prefill backend list based on what we found
-    # Order matches the priority in MLACommonImpl.__init__
-    prefill_backends: list[dict[str, Any]] = []
-
-    # TRT-LLM Ragged (highest priority if available)
-    if trtllm_cc:
-        prefill_backends.append(
-            {
-                "name": "TRT-LLM Ragged‡",
-                "description": "TensorRT-LLM ragged attention",
-                "compute_capability": trtllm_cc,
-                "enable": "Default on SM100",
-                "disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
-                "notes": "DeepSeek R1 dims only",
-            }
-        )
-
-    # FlashInfer prefill
-    if flashinfer_cc:
-        prefill_backends.append(
-            {
-                "name": "FlashInfer",
-                "description": "FlashInfer CUTLASS backend",
-                "compute_capability": flashinfer_cc,
-                "enable": "`-ac.disable_flashinfer_prefill=0`",
-                "disable": "`-ac.disable_flashinfer_prefill=1`",
-                "notes": "DeepSeek R1 dims only",
-            }
-        )
-
-    # cuDNN prefill
-    if cudnn_cc:
-        prefill_backends.append(
-            {
-                "name": "cuDNN",
-                "description": "cuDNN-based attention",
-                "compute_capability": cudnn_cc,
-                "enable": "`-ac.use_cudnn_prefill=1`",
-                "disable": "`-ac.use_cudnn_prefill=0`",
-                "notes": "",
-            }
-        )
-
-    # FlashAttention is always available as fallback
-    prefill_backends.append(
-        {
-            "name": "FlashAttention",
-            "description": "FlashAttention varlen (FA2/FA3)",
-            "compute_capability": "Any",
-            "enable": "Default fallback",
-            "disable": "Use other backends",
-            "notes": "FA3 on SM90, FA2 otherwise",
-        }
-    )
-
-    return prefill_backends
-
-
-def find_class_in_ast(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
-    """Find a class definition in an AST."""
-    for node in ast.walk(tree):
-        if isinstance(node, ast.ClassDef) and node.name == class_name:
-            return node
-    return None
-
-
-def find_method(node: ast.ClassDef, method_name: str) -> ast.FunctionDef | None:
-    """Find a method in a class definition."""
-    for item in node.body:
-        if isinstance(item, ast.FunctionDef) and item.name == method_name:
-            return item
-    return None
-
-
-def method_returns_true(method: ast.FunctionDef | None) -> bool:
-    """Check if a method simply returns True."""
-    if method is None:
-        return False
-    for node in ast.walk(method):
-        if not isinstance(node, ast.Return):
-            continue
-        if isinstance(node.value, ast.Constant) and node.value.value is True:
-            return True
-    return False
-
-
-def _parse_list_class_var(node: ast.ClassDef, var_name: str) -> list[str] | None:
-    """Parse a list-type class variable, returning None if not found."""
-    for item in node.body:
-        if not isinstance(item, ast.AnnAssign):
-            continue
-        if not isinstance(item.target, ast.Name):
-            continue
-        if item.target.id != var_name:
-            continue
-        if not (item.value and isinstance(item.value, ast.List)):
-            continue
-        result = []
-        for elt in item.value.elts:
-            if isinstance(elt, ast.Attribute):
-                result.append(elt.attr)
-            elif isinstance(elt, ast.Constant):
-                result.append(str(elt.value))
-        return result
-    return None
+# ---------------------------------------------------------------------------
+# Backend feature extraction from AST
+# ---------------------------------------------------------------------------
 
 
 def parse_supported_dtypes(node: ast.ClassDef) -> str:
@@ -432,35 +332,6 @@ def parse_kv_cache_dtypes(node: ast.ClassDef) -> str:
     return "auto"
 
 
-def _parse_return_list(
-    method: ast.FunctionDef | None, handle_multiple_of: bool = False
-) -> list[str]:
-    """Extract list items from a method's return statement."""
-    if method is None:
-        return []
-    for stmt in ast.walk(method):
-        if not isinstance(stmt, ast.Return):
-            continue
-        if not isinstance(stmt.value, ast.List):
-            continue
-        sizes = []
-        for elt in stmt.value.elts:
-            if isinstance(elt, ast.Constant):
-                sizes.append(str(elt.value))
-            elif (
-                handle_multiple_of
-                and isinstance(elt, ast.Call)
-                and isinstance(elt.func, ast.Name)
-                and elt.func.id == "MultipleOf"
-                and elt.args
-                and isinstance(elt.args[0], ast.Constant)
-            ):
-                sizes.append(f"%{elt.args[0].value}")
-        if sizes:
-            return sizes
-    return []
-
-
 def parse_block_sizes(node: ast.ClassDef) -> str:
     """Parse get_supported_kernel_block_sizes method."""
     method = find_method(node, "get_supported_kernel_block_sizes")
@@ -573,10 +444,61 @@ def parse_attention_types(node: ast.ClassDef) -> str:
     return "All" if len(types) >= 3 else ", ".join(sorted(types))
 
 
-def check_method_overrides(node: ast.ClassDef, method_name: str) -> bool:
-    """Check if a method is overridden and returns True."""
-    method = find_method(node, method_name)
-    return method_returns_true(method)
+def parse_impl_bool_attr(
+    tree: ast.AST,
+    class_name: str,
+    attr_name: str,
+    default: bool = False,
+    source_file: Path | None = None,
+    _visited: set[str] | None = None,
+) -> bool:
+    """Parse a boolean class attribute from an impl class, following inheritance.
+
+    Walks up the inheritance chain within the same file and across files
+    (by resolving imports) to find the attribute value.
+    """
+    if _visited is None:
+        _visited = set()
+    if class_name in _visited:
+        return default
+    _visited.add(class_name)
+
+    class_node = find_class_in_ast(tree, class_name)
+    if class_node is None:
+        return default
+
+    # Check directly on this class
+    value = _find_bool_class_var(class_node, attr_name)
+    if value is not None:
+        return value
+
+    # Check parent class
+    parent_name = _get_parent_class_name(class_node)
+    if parent_name:
+        # Try parent in same file first
+        parent_node = find_class_in_ast(tree, parent_name)
+        if parent_node is not None:
+            return parse_impl_bool_attr(
+                tree, parent_name, attr_name, default, source_file, _visited
+            )
+
+        # Try resolving cross-file import
+        parent_file = _resolve_import_to_file(tree, parent_name, source_file)
+        if parent_file:
+            try:
+                parent_tree = ast.parse(parent_file.read_text())
+                return parse_impl_bool_attr(
+                    parent_tree,
+                    parent_name,
+                    attr_name,
+                    default,
+                    parent_file,
+                    _visited,
+                )
+            except Exception:
+                pass
+
+    return default
 
 
 def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None:
@@ -597,10 +519,7 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
         return None
 
     # Check if this is an MLA backend by parent class or naming
-    parent = None
-    if class_node.bases:
-        base = class_node.bases[0]
-        parent = base.id if isinstance(base, ast.Name) else None
+    parent = _get_parent_class_name(class_node)
     mla_parents = {"MLACommonBackend", "FlashMLABackend", "FlashMLASparseBackend"}
     is_mla_backend = (
         parent in mla_parents
@@ -612,6 +531,21 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
     is_non_cuda = backend_name.startswith(("CPU_", "ROCM_"))
     compute_cap = "N/A" if is_non_cuda else parse_compute_capability(class_node)
 
+    # Parse impl class features (DCP support)
+    impl_method = find_method(class_node, "get_impl_cls")
+    impl_class_name = None
+    if impl_method:
+        for stmt in ast.walk(impl_method):
+            if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name):
+                impl_class_name = stmt.value.id
+                break
+
+    supports_dcp = False
+    if impl_class_name:
+        supports_dcp = parse_impl_bool_attr(
+            tree, impl_class_name, "can_return_lse_for_decode", False, file_path
+        )
+
     return {
         "name": backend_name,
         "dtypes": parse_supported_dtypes(class_node),
@@ -624,114 +558,293 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
         "supports_sink": check_method_overrides(class_node, "supports_sink"),
         "is_sparse": check_method_overrides(class_node, "is_sparse"),
         "supports_mm_prefix": check_method_overrides(class_node, "supports_mm_prefix"),
+        "supports_dcp": supports_dcp,
     }
 
 
-def add_literal_quotes(value: str) -> str:
-    """Add literal backticks around all comma-separated items in a string."""
-    items = [item.strip() for item in value.split(",")]
-    quoted_items = [f"`{item}`" for item in items]
-    return ", ".join(quoted_items)
+# ---------------------------------------------------------------------------
+# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
+# ---------------------------------------------------------------------------
 
 
-def bool_to_emoji(value: bool) -> str:
-    """Convert a boolean to a checkmark or X emoji."""
-    return "✅" if value else "❌"
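To see the inheritance walk in action, here is a single-file sketch (class names invented; assumes `parse_impl_bool_attr` and its helpers are in scope):

```python
# MyImpl does not define the flag itself, so the helper recurses into BaseImpl.
import ast

source = """
class BaseImpl:
    can_return_lse_for_decode: bool = True

class MyImpl(BaseImpl):
    pass
"""

tree = ast.parse(source)
assert parse_impl_bool_attr(tree, "MyImpl", "can_return_lse_for_decode") is True
```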
-def generate_markdown_table(
-    backends: list[dict[str, Any]], title: str, is_mla_table: bool = False
-) -> str:
-    """Generate a markdown table from backend info.
-
-    Args:
-        backends: List of backend info dictionaries.
-        title: Table title.
-        is_mla_table: If True, include MLA and Sparse columns (for MLA table).
-            If False, exclude them (for standard attention table).
-    """
-    if not backends:
-        return f"## {title}\n\nNo backends found.\n"
-
-    # Check if any backend has a version (for FA2/FA3 split)
-    has_versions = any(b.get("version") for b in backends)
-
-    if is_mla_table:
-        header = (
-            "| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
-            "| Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |"
-        )
-        separator = (
-            "|---------|--------|-----------|-------------|------------"
-            "|------|--------|-----------|-----------------|--------------|"
-        )
-    elif has_versions:
-        header = (
-            "| Backend | Version | Dtypes | KV Dtypes | Block Sizes "
-            "| Head Sizes | Sink | MM Prefix | Attention Types | Compute Cap. |"
-        )
-        separator = (
-            "|---------|---------|--------|-----------|-------------"
-            "|------------|------|-----------|-----------------|--------------|"
-        )
-    else:
-        header = (
-            "| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
-            "| Sink | MM Prefix | Attention Types | Compute Cap. |"
-        )
-        separator = (
-            "|---------|--------|-----------|-------------|------------"
-            "|------|-----------|-----------------|--------------|"
-        )
-    lines = [f"## {title}", "", header, separator]
-
-    def sort_key(x: dict[str, Any]) -> tuple[str, int]:
-        """Sort key that keeps parent/child rows together in order."""
-        return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
-
-    for info in sorted(backends, key=sort_key):
-        if is_mla_table:
-            row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |".format(
-                info["name"],
-                info["dtypes"],
-                add_literal_quotes(info["kv_cache_dtypes"]),
-                info["block_sizes"],
-                info["head_sizes"],
-                bool_to_emoji(info["supports_sink"]),
-                bool_to_emoji(info["is_sparse"]),
-                bool_to_emoji(info["supports_mm_prefix"]),
-                info["attn_types"],
-                info["compute_capability"],
-            )
-        elif has_versions:
-            row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |".format(
-                info["name"],
-                info.get("version", ""),
-                info["dtypes"],
-                add_literal_quotes(info["kv_cache_dtypes"]),
-                info["block_sizes"],
-                info["head_sizes"],
-                bool_to_emoji(info["supports_sink"]),
-                bool_to_emoji(info["supports_mm_prefix"]),
-                info["attn_types"],
-                info["compute_capability"],
-            )
        else:
-            row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} |".format(
-                info["name"],
-                info["dtypes"],
-                add_literal_quotes(info["kv_cache_dtypes"]),
-                info["block_sizes"],
-                info["head_sizes"],
-                bool_to_emoji(info["supports_sink"]),
-                bool_to_emoji(info["supports_mm_prefix"]),
-                info["attn_types"],
-                info["compute_capability"],
-            )
-        lines.append(row)
-
-    lines.append("")
-    return "\n".join(lines)
+def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
+    """Parse fa_utils.py to detect FA2 vs FA3 feature differences.
+
+    Returns a dict with 'fa2' and 'fa3' keys containing their respective
+    feature overrides for compute capability, KV cache dtypes, and sink support.
+    """
+    if not FA_UTILS_FILE.exists():
+        return {}
+
+    try:
+        tree = ast.parse(FA_UTILS_FILE.read_text())
+    except Exception:
+        return {}
+
+    # Analyze the functions to determine FA3-specific features
+    fa3_supports_fp8 = False
+    fa3_supports_sinks = False
+    fa3_compute_cap: str | None = None
+
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.FunctionDef):
+            continue
+
+        # Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
+        if node.name == "flash_attn_supports_fp8":
+            for n in ast.walk(node):
+                if (
+                    isinstance(n, ast.Compare)
+                    and isinstance(n.left, ast.Call)
+                    and isinstance(n.left.func, ast.Name)
+                    and n.left.func.id == "get_flash_attn_version"
+                ):
+                    fa3_supports_fp8 = True
+                    break
+
+        # Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
+        if node.name == "flash_attn_supports_sinks":
+            for n in ast.walk(node):
+                if (
+                    isinstance(n, ast.Compare)
+                    and isinstance(n.left, ast.Call)
+                    and isinstance(n.left.func, ast.Name)
+                    and n.left.func.id == "get_flash_attn_version"
+                ):
+                    fa3_supports_sinks = True
+                    break
+
+        # Check get_flash_attn_version for FA3 compute capability
+        # Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
+        if node.name == "get_flash_attn_version":
+            for n in ast.walk(node):
+                # Look for IfExp (ternary) with `device_capability.major == 9`
+                if isinstance(n, ast.IfExp):
+                    test = n.test
+                    # Check if test is a BoolOp (and) containing the major check
+                    if isinstance(test, ast.BoolOp):
+                        for val in test.values:
+                            if (
+                                isinstance(val, ast.Compare)
+                                and isinstance(val.left, ast.Attribute)
+                                and val.left.attr == "major"
+                                and val.comparators
+                                and isinstance(val.comparators[0], ast.Constant)
+                            ):
+                                fa3_compute_cap = f"{val.comparators[0].value}.x"
+                                break
+
+    return {
+        "fa2": {
+            "supports_fp8": False,
+            "supports_sink": False,
+        },
+        "fa3": {
+            "compute_capability": fa3_compute_cap,
+            "supports_fp8": fa3_supports_fp8,
+            "supports_sink": fa3_supports_sinks,
+        },
+    }
+
+
+def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
+    """Parse flashinfer.py to detect TRTLLM-specific features.
+
+    FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
+    capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
+    """
+    if not FLASHINFER_UTILS_FILE.exists():
+        return {}
+
+    try:
+        tree = ast.parse(FLASHINFER_UTILS_FILE.read_text())
+    except Exception:
+        return {}
+
+    trtllm_compute_cap = _find_cc_in_function(tree, "supports_trtllm_attention")
+
+    if not trtllm_compute_cap:
+        return {}
+
+    return {
+        "native": {
+            # Native FlashInfer: everything except SM100
+            "supports_sink": False,
+        },
+        "trtllm": {
+            # TRTLLM pathway on Blackwell
+            "compute_capability": trtllm_compute_cap,
+            "supports_sink": True,
+        },
+    }
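For reference, this is the kind of source pattern the FA3 checks above match; the toy function body is invented, but the `ast` condition mirrors the parser's:

```python
# Any `get_flash_attn_version() == ...` comparison inside the function is
# taken as evidence that the feature is FA3-gated.
import ast

source = """
def flash_attn_supports_fp8() -> bool:
    return get_flash_attn_version() == 3
"""

found = any(
    isinstance(n, ast.Compare)
    and isinstance(n.left, ast.Call)
    and isinstance(n.left.func, ast.Name)
    and n.left.func.id == "get_flash_attn_version"
    for n in ast.walk(ast.parse(source))
)
assert found
```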
+
+
+def parse_mla_prefill_backends() -> list[dict[str, Any]]:
+    """Parse MLA prefill backend options from mla_attention.py.
+
+    MLA uses different backends for prefill vs decode. The decode backends are
+    registered in the registry, but prefill backends are selected at runtime
+    based on conditions in MLACommonImpl.__init__.
+
+    Returns a list of prefill backend info dicts with their requirements.
+    """
+    if not MLA_ATTENTION_FILE.exists():
+        return []
+
+    try:
+        tree = ast.parse(MLA_ATTENTION_FILE.read_text())
+    except Exception:
+        return []
+
+    # Find compute capability requirements by parsing use_* functions
+    trtllm_cc = _find_cc_in_function(tree, "use_trtllm_ragged_deepseek_prefill")
+    flashinfer_cc = _find_cc_in_function(tree, "use_flashinfer_prefill")
+    cudnn_cc = _find_cc_in_function(tree, "use_cudnn_prefill")
+
+    # Build prefill backend list based on what we found
+    # Order matches the priority in MLACommonImpl.__init__
+    prefill_backends: list[dict[str, Any]] = []
+
+    # TRT-LLM Ragged (highest priority if available)
+    if trtllm_cc:
+        prefill_backends.append(
+            {
+                "name": "TRT-LLM Ragged‡",
+                "description": "TensorRT-LLM ragged attention",
+                "compute_capability": trtllm_cc,
+                "enable": "Default on SM100",
+                "disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
+                "notes": "DeepSeek R1 dims only",
+            }
+        )
+
+    # FlashInfer prefill
+    if flashinfer_cc:
+        prefill_backends.append(
+            {
+                "name": "FlashInfer",
+                "description": "FlashInfer CUTLASS backend",
+                "compute_capability": flashinfer_cc,
+                "enable": "`-ac.disable_flashinfer_prefill=0`",
+                "disable": "`-ac.disable_flashinfer_prefill=1`",
+                "notes": "DeepSeek R1 dims only",
+            }
+        )
+
+    # cuDNN prefill
+    if cudnn_cc:
+        prefill_backends.append(
+            {
+                "name": "cuDNN",
+                "description": "cuDNN-based attention",
+                "compute_capability": cudnn_cc,
+                "enable": "`-ac.use_cudnn_prefill=1`",
+                "disable": "`-ac.use_cudnn_prefill=0`",
+                "notes": "",
+            }
+        )
+
+    # FlashAttention is always available as fallback
+    prefill_backends.append(
+        {
+            "name": "FlashAttention",
+            "description": "FlashAttention varlen (FA2/FA3)",
+            "compute_capability": "Any",
+            "enable": "Default fallback",
+            "disable": "Use other backends",
+            "notes": "FA3 on SM90, FA2 otherwise",
+        }
+    )
+
+    return prefill_backends
+
+
+# ---------------------------------------------------------------------------
+# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
+# ---------------------------------------------------------------------------
+
+
+def _expand_flash_attn_variants(
+    all_backends: list[dict[str, Any]],
+    fa_features: dict[str, dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
+    expanded = []
+    for backend in all_backends:
+        if backend["name"] != "FLASH_ATTN":
+            backend.setdefault("_sort_key", backend["name"])
+            backend.setdefault("_sort_order", 0)
+            backend.setdefault("version", "")
+            expanded.append(backend)
+            continue
+
+        # Create FA2 entry (keeps base backend's compute_capability)
+        fa2 = backend.copy()
+        fa2["version"] = "FA2*"
+        fa2["_sort_key"] = "FLASH_ATTN"
+        fa2["_sort_order"] = 0
+        fa2["supports_sink"] = fa_features["fa2"]["supports_sink"]
+
+        # Create FA3 entry (uses parsed compute_capability from fa_utils)
+        fa3 = backend.copy()
+        fa3["version"] = "FA3*"
+        fa3["_sort_key"] = "FLASH_ATTN"
+        fa3["_sort_order"] = 1
+        if fa_features["fa3"]["compute_capability"]:
+            fa3["compute_capability"] = fa_features["fa3"]["compute_capability"]
+        fa3["supports_sink"] = fa_features["fa3"]["supports_sink"]
+        if fa_features["fa3"]["supports_fp8"]:
+            base_dtypes = backend["kv_cache_dtypes"].split(", ")
+            fp8_dtypes = ["fp8", "fp8_e4m3", "fp8_e5m2"]
+            new_dtypes = [d for d in fp8_dtypes if d not in base_dtypes]
+            fa3["kv_cache_dtypes"] = ", ".join(base_dtypes + new_dtypes)
+
+        expanded.append(fa2)
+        expanded.append(fa3)
+    return expanded
+
+
+def _expand_flashinfer_variants(
+    all_backends: list[dict[str, Any]],
+    fi_features: dict[str, dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """Expand FLASHINFER into native and TRTLLM variants."""
+    expanded = []
+    for backend in all_backends:
+        if backend["name"] != "FLASHINFER":
+            expanded.append(backend)
+            continue
+
+        # Parse original compute capability to get min CC
+        orig_cap = backend["compute_capability"]
+        parts = orig_cap.replace(".x", "").split("-")
+        min_cc = parts[0] if parts else "7"
+        trtllm_cc = fi_features["trtllm"]["compute_capability"]
+
+        # Create native entry (pre-Blackwell GPUs)
+        native = backend.copy()
+        native["version"] = "Native†"
+        native["_sort_key"] = "FLASHINFER"
+        native["_sort_order"] = 0
+        native["supports_sink"] = fi_features["native"]["supports_sink"]
+        native["compute_capability"] = f"{min_cc}.x-9.x"
+
+        # Create TRTLLM entry
+        trtllm = backend.copy()
+        trtllm["version"] = "TRTLLM†"
+        trtllm["_sort_key"] = "FLASHINFER"
+        trtllm["_sort_order"] = 1
+        trtllm["compute_capability"] = trtllm_cc
+        trtllm["supports_sink"] = fi_features["trtllm"]["supports_sink"]
+
+        expanded.append(native)
+        expanded.append(trtllm)
+    return expanded
+
+
+# ---------------------------------------------------------------------------
+# CUDA priority list parsing
+# ---------------------------------------------------------------------------
 
 
 def parse_cuda_priority_lists() -> dict[str, list[str]]:
@@ -827,6 +940,105 @@ def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: str
     priorities[f"{prefix}_default"] = backends
 
 
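A toy run of `_expand_flash_attn_variants` (field values fabricated) showing the FA2/FA3 split it produces:

```python
# One FLASH_ATTN record in, two version-tagged rows out.
backend = {
    "name": "FLASH_ATTN",
    "kv_cache_dtypes": "auto, bfloat16",
    "compute_capability": "≥8.0",
    "supports_sink": False,
}
fa_features = {
    "fa2": {"supports_fp8": False, "supports_sink": False},
    "fa3": {"compute_capability": "9.x", "supports_fp8": True, "supports_sink": True},
}

rows = _expand_flash_attn_variants([backend], fa_features)
assert [r["version"] for r in rows] == ["FA2*", "FA3*"]
assert rows[1]["compute_capability"] == "9.x"
assert rows[1]["kv_cache_dtypes"] == "auto, bfloat16, fp8, fp8_e4m3, fp8_e5m2"
```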
+# ---------------------------------------------------------------------------
+# Data-driven table rendering
+#
+# Each column is a (header, formatter) pair. The formatter takes a backend
+# info dict and returns the cell string. Tables are assembled by selecting
+# which columns to include, then calling _render_table().
+# ---------------------------------------------------------------------------
+
+# Column type alias for readability
+TableColumn = tuple[str, Callable[[dict[str, Any]], str]]
+
+# Shared column definitions -- order here matches the output table order
+_COL_BACKEND: TableColumn = ("Backend", lambda b: f"`{b['name']}`")
+_COL_VERSION: TableColumn = ("Version", lambda b: b.get("version", ""))
+_COL_DTYPES: TableColumn = ("Dtypes", lambda b: b["dtypes"])
+_COL_KV_DTYPES: TableColumn = (
+    "KV Dtypes",
+    lambda b: add_literal_quotes(b["kv_cache_dtypes"]),
+)
+_COL_BLOCK_SIZES: TableColumn = ("Block Sizes", lambda b: b["block_sizes"])
+_COL_HEAD_SIZES: TableColumn = ("Head Sizes", lambda b: b["head_sizes"])
+_COL_SINK: TableColumn = ("Sink", lambda b: bool_to_emoji(b["supports_sink"]))
+_COL_SPARSE: TableColumn = ("Sparse", lambda b: bool_to_emoji(b["is_sparse"]))
+_COL_MM_PREFIX: TableColumn = (
+    "MM Prefix",
+    lambda b: bool_to_emoji(b["supports_mm_prefix"]),
+)
+_COL_DCP: TableColumn = ("DCP", lambda b: bool_to_emoji(b["supports_dcp"]))
+_COL_ATTN_TYPES: TableColumn = ("Attention Types", lambda b: b["attn_types"])
+_COL_COMPUTE_CAP: TableColumn = ("Compute Cap.", lambda b: b["compute_capability"])
+
+
+def add_literal_quotes(value: str) -> str:
+    """Add literal backticks around all comma-separated items in a string."""
+    items = [item.strip() for item in value.split(",")]
+    return ", ".join(f"`{item}`" for item in items)
+
+
+def bool_to_emoji(value: bool) -> str:
+    """Convert a boolean to a checkmark or X emoji."""
+    return "✅" if value else "❌"
+
+
+def _build_columns(is_mla: bool, has_versions: bool) -> list[TableColumn]:
+    """Build the column list for a backend feature table.
+
+    The column selection depends on whether it's an MLA table (includes
+    Sparse column) and whether any backend has version variants (includes
+    Version column).
+    """
+    cols: list[TableColumn] = [_COL_BACKEND]
+    if has_versions:
+        cols.append(_COL_VERSION)
+    cols.extend([_COL_DTYPES, _COL_KV_DTYPES, _COL_BLOCK_SIZES, _COL_HEAD_SIZES])
+    cols.append(_COL_SINK)
+    if is_mla:
+        cols.append(_COL_SPARSE)
+    cols.extend([_COL_MM_PREFIX, _COL_DCP, _COL_ATTN_TYPES, _COL_COMPUTE_CAP])
+    return cols
+
+
+def _sort_key(x: dict[str, Any]) -> tuple[str, int]:
+    """Sort key that keeps parent/child rows together in order."""
+    return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
+
+
+def _render_table(
+    columns: list[TableColumn],
+    backends: list[dict[str, Any]],
+) -> list[str]:
+    """Render a markdown table from column specs and backend data."""
+    header = "| " + " | ".join(name for name, _ in columns) + " |"
+    sep = "|" + "|".join("-" * (len(name) + 2) for name, _ in columns) + "|"
+    lines = [header, sep]
+    for info in sorted(backends, key=_sort_key):
+        row = "| " + " | ".join(fmt(info) for _, fmt in columns) + " |"
+        lines.append(row)
+    return lines
+
+
+def generate_markdown_table(
+    backends: list[dict[str, Any]], title: str, is_mla_table: bool = False
+) -> str:
+    """Generate a titled markdown table from backend info."""
+    if not backends:
+        return f"## {title}\n\nNo backends found.\n"
+    has_versions = any(b.get("version") for b in backends)
+    columns = _build_columns(is_mla_table, has_versions)
+    lines = [f"## {title}", ""]
+    lines.extend(_render_table(columns, backends))
+    lines.append("")
+    return "\n".join(lines)
+
+
+# ---------------------------------------------------------------------------
+# Markdown section generators (usage, priority, legend, MLA)
+# ---------------------------------------------------------------------------
+
+
 def generate_usage_section() -> str:
     """Generate the usage documentation section."""
     return """## Setting the Attention Backend
@@ -959,6 +1171,27 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
     return "\n".join(lines)
 
 
+def generate_legend() -> str:
+    """Generate a legend explaining the table columns."""
+    return """## Legend
+
+| Column | Description |
+|--------|-------------|
+| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
+| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
+| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
+| **Head Sizes** | Supported attention head sizes |
+| **Sink** | Attention sink support (for StreamingLLM) |
+| **Sparse** | Sparse attention support (MLA only) |
+| **MM Prefix** | Multimodal prefix full attention support |
+| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
+| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
+| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
+
+**Symbols:** ✅ = Supported, ❌ = Not supported
+"""
+
+
 def generate_mla_section(
     prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
 ) -> str:
@@ -999,57 +1232,17 @@ def generate_mla_section(
         ]
     )
 
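The renderer can be exercised on a single fabricated record; only the fields read by the selected columns need to be present:

```python
# Sketch of the data-driven renderer; all field values below are made up.
info = {
    "name": "TOY_ATTN",
    "dtypes": "fp16, bf16",
    "kv_cache_dtypes": "auto",
    "block_sizes": "Any",
    "head_sizes": "Any",
    "supports_sink": False,
    "supports_mm_prefix": False,
    "supports_dcp": True,
    "attn_types": "Decoder",
    "compute_capability": "Any",
}

columns = _build_columns(is_mla=False, has_versions=False)
for line in _render_table(columns, [info]):
    print(line)
# -> header row, separator row, and one data row ending "| ✅ | Decoder | Any |"
```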
-    # Generate decode backends table
-    header = (
-        "| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
-        "| Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |"
-    )
-    separator = (
-        "|---------|--------|-----------|-------------|------------"
-        "|------|--------|-----------|-----------------|--------------|"
-    )
-    lines.extend([header, separator])
-
-    def sort_key(x: dict[str, Any]) -> tuple[str, int]:
-        return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
-
-    for info in sorted(decode_backends, key=sort_key):
-        row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |".format(
-            info["name"],
-            info["dtypes"],
-            add_literal_quotes(info["kv_cache_dtypes"]),
-            info["block_sizes"],
-            info["head_sizes"],
-            bool_to_emoji(info["supports_sink"]),
-            bool_to_emoji(info["is_sparse"]),
-            bool_to_emoji(info["supports_mm_prefix"]),
-            info["attn_types"],
-            info["compute_capability"],
-        )
-        lines.append(row)
+    # Reuse data-driven table rendering for decode backends
+    columns = _build_columns(is_mla=True, has_versions=False)
+    lines.extend(_render_table(columns, decode_backends))
 
     lines.append("")
     return "\n".join(lines)
 
 
-def generate_legend() -> str:
-    """Generate a legend explaining the table columns."""
-    return """## Legend
-
-| Column | Description |
-|--------|-------------|
-| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
-| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
-| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
-| **Head Sizes** | Supported attention head sizes |
-| **Sink** | Attention sink support (for StreamingLLM) |
-| **Sparse** | Sparse attention support (MLA only) |
-| **MM Prefix** | Multimodal prefix full attention support |
-| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
-| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
-
-**Symbols:** ✅ = Supported, ❌ = Not supported
-"""
+# ---------------------------------------------------------------------------
+# Top-level orchestration
+# ---------------------------------------------------------------------------
 
 
 def generate_docs() -> str:
@@ -1071,86 +1264,17 @@ def generate_docs() -> str:
     # Collect backend info
     all_backends = []
     for backend_name, class_path in attention_backends_map.items():
-        if backend_name in ("CUSTOM", "TORCH_SDPA"):
+        if backend_name in SKIP_BACKENDS:
             continue
         info = analyze_backend(backend_name, class_path)
         if info:
             all_backends.append(info)
 
-    # Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities
+    # Expand backends into version variants
     if fa_features:
-        expanded_backends = []
-        for backend in all_backends:
-            if backend["name"] == "FLASH_ATTN":
-                # Create FA2 entry (keeps base backend's compute_capability)
-                fa2 = backend.copy()
-                fa2["name"] = "FLASH_ATTN"
-                fa2["version"] = "FA2*"
-                fa2["_sort_key"] = "FLASH_ATTN"
-                fa2["_sort_order"] = 0
-                fa2["supports_sink"] = fa_features["fa2"]["supports_sink"]
-
-                # Create FA3 entry (uses parsed compute_capability from fa_utils)
-                fa3 = backend.copy()
-                fa3["name"] = "FLASH_ATTN"
-                fa3["version"] = "FA3*"
-                fa3["_sort_key"] = "FLASH_ATTN"
-                fa3["_sort_order"] = 1
-                if fa_features["fa3"]["compute_capability"]:
-                    fa3["compute_capability"] = fa_features["fa3"]["compute_capability"]
-                fa3["supports_sink"] = fa_features["fa3"]["supports_sink"]
-                if fa_features["fa3"]["supports_fp8"]:
-                    # Add fp8 dtypes to the base backend's kv_cache_dtypes
-                    base_dtypes = backend["kv_cache_dtypes"].split(", ")
-                    fp8_dtypes = ["fp8", "fp8_e4m3", "fp8_e5m2"]
-                    new_dtypes = [d for d in fp8_dtypes if d not in base_dtypes]
-                    fa3["kv_cache_dtypes"] = ", ".join(base_dtypes + new_dtypes)
-
-                # Add FA2 first, then FA3
-                expanded_backends.append(fa2)
-                expanded_backends.append(fa3)
-            else:
-                backend["_sort_key"] = backend["name"]
-                backend["_sort_order"] = 0
-                backend["version"] = ""  # No version for other backends
-                expanded_backends.append(backend)
-        all_backends = expanded_backends
-
-    # Expand FLASHINFER into native and TRTLLM variants
+        all_backends = _expand_flash_attn_variants(all_backends, fa_features)
     if fi_features:
-        expanded_backends = []
-        for backend in all_backends:
-            if backend["name"] == "FLASHINFER":
-                # Parse original compute capability to get min CC
-                orig_cap = backend["compute_capability"]
-                parts = orig_cap.replace(".x", "").split("-")
-                min_cc = parts[0] if parts else "7"
-                trtllm_cc = fi_features["trtllm"]["compute_capability"]
-
-                # Create native entry (pre-Blackwell GPUs)
-                native = backend.copy()
-                native["name"] = "FLASHINFER"
-                native["version"] = "Native†"
-                native["_sort_key"] = "FLASHINFER"
-                native["_sort_order"] = 0
-                native["supports_sink"] = fi_features["native"]["supports_sink"]
-                # Native FlashInfer is used on GPUs before SM100 (Blackwell)
-                native["compute_capability"] = f"{min_cc}.x-9.x"
-
-                # Create TRTLLM entry
-                trtllm = backend.copy()
-                trtllm["name"] = "FLASHINFER"
-                trtllm["version"] = "TRTLLM†"
-                trtllm["_sort_key"] = "FLASHINFER"
-                trtllm["_sort_order"] = 1
-                trtllm["compute_capability"] = trtllm_cc
-                trtllm["supports_sink"] = fi_features["trtllm"]["supports_sink"]
-
-                expanded_backends.append(native)
-                expanded_backends.append(trtllm)
-            else:
-                expanded_backends.append(backend)
-        all_backends = expanded_backends
+        all_backends = _expand_flashinfer_variants(all_backends, fi_features)
 
     # Split into MLA and non-MLA
     mla_backends = [b for b in all_backends if b["is_mla"]]
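To sanity-check the refactor end to end, something like the following works from the repo root; the module loading and the comparison are illustrative, not the actual pre-commit wiring:

```python
# Hypothetical driver: load the generator script by path, regenerate the
# tables, and naively check them against the checked-in markdown.
import importlib.util
from pathlib import Path

spec = importlib.util.spec_from_file_location(
    "gen_docs", Path("tools/pre_commit/generate_attention_backend_docs.py")
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

generated = mod.generate_docs()
current = Path("docs/design/attention_backends.md").read_text()
# Naive containment check; the real hook may compare differently.
print("up to date" if generated in current else "docs need regeneration")
```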