diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 47ac91464..c748aa9e0 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -167,7 +167,7 @@ Priority is **1 = highest** (tried first). | ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ | | `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x | -| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | +| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py index bbbf4f4b6..9e14f8739 100644 --- a/tools/pre_commit/generate_attention_backend_docs.py +++ b/tools/pre_commit/generate_attention_backend_docs.py @@ -235,10 +235,11 @@ def _resolve_import_to_file( def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None: - """Find a compute capability from is_device_capability_family() calls in a function. + """Find a compute capability from is_device_capability*() calls in a function. - Looks for the pattern: current_platform.is_device_capability_family(N) - and converts N (e.g. 100) to a CC string (e.g. "10.x"). + Handles two patterns: + - is_device_capability_family(N): "M.x" (e.g. 100 -> "10.x") + - is_device_capability(N): "M.m" (e.g. 100 -> "10.0") """ for node in ast.walk(tree): if not isinstance(node, ast.FunctionDef) or node.name != func_name: @@ -247,12 +248,15 @@ def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None: if ( isinstance(n, ast.Call) and isinstance(n.func, ast.Attribute) - and n.func.attr == "is_device_capability_family" and n.args and isinstance(n.args[0], ast.Constant) and isinstance(n.args[0].value, int) ): - return f"{n.args[0].value // 10}.x" + val = n.args[0].value + if n.func.attr == "is_device_capability_family": + return f"{val // 10}.x" + elif n.func.attr == "is_device_capability": + return f"{val // 10}.{val % 10}" return None diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 065a9ca89..9f4921d25 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -289,10 +289,10 @@ def supports_trtllm_attention() -> bool: if envs.VLLM_BATCH_INVARIANT: return False - # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - return ( - current_platform.is_device_capability_family(100) and has_nvidia_artifactory() - ) + # TRTLLM attention is currently only validated on SM100 (CC 10.0). + # SM103 (GB300) hangs with FlashInfer >= 0.6.7. + # See: https://github.com/flashinfer-ai/flashinfer/issues/2939 + return current_platform.is_device_capability(100) and has_nvidia_artifactory() def force_use_trtllm_attention() -> bool | None: