diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml index 69390cd6d..d8957c217 100644 --- a/.buildkite/test_areas/misc.yaml +++ b/.buildkite/test_areas/misc.yaml @@ -9,6 +9,7 @@ steps: - tests/v1 commands: - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + - export VLLM_WORKER_MULTIPROC_METHOD=spawn # split the test to avoid interference - pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s v1/executor diff --git a/.gitignore b/.gitignore index 8e864d090..795071bd7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* +!vllm/vllm_flash_attn/__init__.py +!vllm/vllm_flash_attn/flash_attn_interface.py # OpenAI triton kernels copied from source vllm/third_party/triton_kernels/* diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 41c4e308d..c206b9c39 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -17,7 +17,8 @@ endif() # They should be identical but if they aren't, this is a massive footgun. # # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3), +# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files). # If no component is specified, vllm-flash-attn is still installed. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. @@ -38,7 +39,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6 + GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn @@ -46,38 +47,62 @@ else() endif() -# Ensure the vllm/vllm_flash_attn directory exists before installation -install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS) - -# Make sure vllm-flash-attn install rules are nested under vllm/ -# This is here to support installing all components under the same prefix with cmake --install. -# setup.py installs every component separately but uses the same prefix for all. -# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3, -# and these statements don't hurt when installing neither component. 
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS) -install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS) +# Install rules for FA components need the install prefix nested under vllm/ +# These run at install time, before the FA library's own install rules +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT}) + install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT}) + install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT}) +endforeach() # Fetch the vllm-flash-attn library FetchContent_MakeAvailable(vllm-flash-attn) message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") -# Restore the install prefix -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) +# Restore the install prefix after FA's install rules +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT}) + install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT}) +endforeach() -# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in -# case only one is built, in the case both are built redundant work is done) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm/vllm_flash_attn - COMPONENT _vllm_fa2_C - FILES_MATCHING PATTERN "*.py" -) +# Install shared Python files for both FA2 and FA3 components +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + # Ensure the vllm/vllm_flash_attn directory exists before installation + install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" + COMPONENT ${_FA_COMPONENT}) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm/vllm_flash_attn - COMPONENT _vllm_fa3_C - FILES_MATCHING PATTERN "*.py" -) + # Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py + # which are source-controlled in vllm) + install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm/vllm_flash_attn + COMPONENT ${_FA_COMPONENT} + FILES_MATCHING PATTERN "*.py" + PATTERN "__init__.py" EXCLUDE + PATTERN "flash_attn_interface.py" EXCLUDE + ) + +endforeach() + +# +# FA4 CuteDSL component +# This is a Python-only component that copies the flash_attn/cute directory +# and transforms imports to match our package structure. +# +add_custom_target(_vllm_fa4_cutedsl_C) + +# Copy flash_attn/cute directory (needed for FA4) and transform imports +# The cute directory uses flash_attn.cute imports internally, which we replace +# with vllm.vllm_flash_attn.cute to match our package structure. 
+install(CODE " + file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\") + foreach(SRC_FILE \${CUTE_PY_FILES}) + file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE}) + set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\") + get_filename_component(DST_DIR \${DST_FILE} DIRECTORY) + file(MAKE_DIRECTORY \${DST_DIR}) + file(READ \${SRC_FILE} FILE_CONTENTS) + string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\") + file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\") + endforeach() +" COMPONENT _vllm_fa4_cutedsl_C) diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 6d5c007e3..e726d9925 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first). | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | +| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | @@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first). > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > -> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise. +> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise. ## MLA (Multi-head Latent Attention) Backends diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 84fe34730..22477dc82 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -11,3 +11,7 @@ torchaudio==2.10.0 torchvision==0.25.0 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.6.4 + +# QuACK and Cutlass DSL for FA4 (cute-DSL implementation) +nvidia-cutlass-dsl>=4.4.0.dev1 +quack-kernels>=0.2.7 diff --git a/setup.py b/setup.py index afdd4b19b..556a511a3 100644 --- a/setup.py +++ b/setup.py @@ -976,6 +976,11 @@ if _is_cuda(): ): # FA3 requires CUDA 12.3 or later ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) + # FA4 CuteDSL - Python-only component for FA4's cute DSL support + # Optional since this doesn't produce a .so file, just copies Python files + ext_modules.append( + CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa4_cutedsl_C", optional=True) + ) if envs.VLLM_USE_PRECOMPILED or ( CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9") ): diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py index 3aca49f94..628656f0d 100644 --- a/tools/pre_commit/generate_attention_backend_docs.py +++ b/tools/pre_commit/generate_attention_backend_docs.py @@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None # --------------------------------------------------------------------------- -# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill) +# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill) # --------------------------------------------------------------------------- -def parse_flash_attn_features() -> dict[str, dict[str, Any]]: - """Parse fa_utils.py to detect FA2 vs FA3 feature differences. +def _parse_fa4_supported_caps() -> str | None: + """Parse flash_attn_interface.py for FA4 supported compute capabilities. - Returns a dict with 'fa2' and 'fa3' keys containing their respective + Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported(). + """ + fa_interface_file = ( + REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py" + ) + if not fa_interface_file.exists(): + return None + + try: + tree = ast.parse(fa_interface_file.read_text()) + except Exception: + return None + + for node in ast.walk(tree): + if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported": + continue + for n in ast.walk(node): + if not ( + isinstance(n, ast.Compare) + and len(n.ops) == 1 + and isinstance(n.ops[0], ast.NotIn) + and isinstance(n.comparators[0], ast.List) + ): + continue + caps: list[int] = [ + e.value + for e in n.comparators[0].elts + if isinstance(e, ast.Constant) and isinstance(e.value, int) + ] + if caps: + caps.sort() + return f"{caps[0]}.x-{caps[-1]}.x" + + return None + + +def parse_flash_attn_features() -> dict[str, dict[str, Any]]: + """Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences. + + Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective feature overrides for compute capability, KV cache dtypes, and sink support. 
""" if not FA_UTILS_FILE.exists(): @@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: fa3_supports_fp8 = False fa3_supports_sinks = False fa3_compute_cap: str | None = None + fa4_compute_cap: str | None = None for node in ast.walk(tree): if not isinstance(node, ast.FunctionDef): @@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: fa3_supports_sinks = True break - # Check get_flash_attn_version for FA3 compute capability - # Look for the ternary: 3 if (device_capability.major == 9 ...) else 2 + # Check get_flash_attn_version for FA3/FA4 compute capability if node.name == "get_flash_attn_version": for n in ast.walk(node): - # Look for IfExp (ternary) with `device_capability.major == 9` + # Handle IfExp (ternary) with `device_capability.major == 9` if isinstance(n, ast.IfExp): test = n.test - # Check if test is a BoolOp (and) containing the major check if isinstance(test, ast.BoolOp): for val in test.values: if ( @@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: fa3_compute_cap = f"{val.comparators[0].value}.x" break + # Handle If statements for FA3/FA4 detection + # e.g. `if device_capability.major == 9` -> FA3 + # `elif device_capability.major >= 10` -> FA4 + if isinstance(n, ast.If): + test = n.test + comparisons = ( + [v for v in test.values if isinstance(v, ast.Compare)] + if isinstance(test, ast.BoolOp) + else [test] + if isinstance(test, ast.Compare) + else [] + ) + for comp in comparisons: + if not ( + isinstance(comp.left, ast.Attribute) + and comp.left.attr == "major" + and comp.comparators + and isinstance(comp.comparators[0], ast.Constant) + and isinstance(comp.comparators[0].value, int) + ): + continue + op = comp.ops[0] + val = comp.comparators[0].value + if isinstance(op, ast.Eq) and fa3_compute_cap is None: + fa3_compute_cap = f"{val}.x" + elif isinstance(op, ast.GtE) and fa4_compute_cap is None: + fa4_compute_cap = f"≥{val}.0" + + # Fallback: try to parse FA4 compute caps from flash_attn_interface.py + if fa4_compute_cap is None: + fa4_compute_cap = _parse_fa4_supported_caps() + return { "fa2": { "supports_fp8": False, @@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: "supports_fp8": fa3_supports_fp8, "supports_sink": fa3_supports_sinks, }, + "fa4": { + "compute_capability": fa4_compute_cap, + "supports_fp8": False, + "supports_sink": False, + }, } @@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]: # --------------------------------------------------------------------------- -# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM) +# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM) # --------------------------------------------------------------------------- @@ -768,7 +843,7 @@ def _expand_flash_attn_variants( all_backends: list[dict[str, Any]], fa_features: dict[str, dict[str, Any]], ) -> list[dict[str, Any]]: - """Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities.""" + """Expand FLASH_ATTN into FA2, FA3, and FA4 variants.""" expanded = [] for backend in all_backends: if backend["name"] != "FLASH_ATTN": @@ -801,6 +876,18 @@ def _expand_flash_attn_variants( expanded.append(fa2) expanded.append(fa3) + + # Create FA4 entry if FA4 features are available + if "fa4" in fa_features: + fa4 = backend.copy() + fa4["version"] = "FA4*" + fa4["_sort_key"] = "FLASH_ATTN" + fa4["_sort_order"] = 2 + if fa_features["fa4"].get("compute_capability"): + fa4["compute_capability"] = 
fa_features["fa4"]["compute_capability"] + fa4["supports_sink"] = fa_features["fa4"]["supports_sink"] + expanded.append(fa4) + return expanded @@ -1360,7 +1447,8 @@ def generate_docs() -> str: if fa_features: footnotes.append( "> **\\*** Specify the FlashAttention version via " - "`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, " + "`--attention-config.flash_attn_version=2`, `3`, or `4`. " + "Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), " "FA2 otherwise." ) if footnotes: diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 97a139c79..74bb3d68f 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -16,8 +16,8 @@ class AttentionConfig: backend: AttentionBackendEnum | None = None """Attention backend to use. If None, will be selected automatically.""" - flash_attn_version: Literal[2, 3] | None = None - """Force vllm to use a specific flash-attention version (2 or 3). + flash_attn_version: Literal[2, 3, 4] | None = None + """Force vllm to use a specific flash-attention version (2, 3, or 4). Only valid when using the flash-attention backend.""" use_prefill_decode_attention: bool = False diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index d444e20da..f6e7ab85d 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # RoCM and the latter has an additional parameter to control # FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() + self.vllm_flash_attn_version = get_flash_attn_version( + head_size=self.qk_head_dim + ) if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = functools.partial( flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index d89366bbd..d902f2ebc 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp): } self._fa_version = ( - get_flash_attn_version() if self.is_flash_attn_backend else None + get_flash_attn_version(head_size=head_size) + if self.is_flash_attn_backend + else None ) if self.attn_backend == AttentionBackendEnum.FLASHINFER: diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 3150ad9a5..9658a7e3c 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -52,7 +52,9 @@ elif current_platform.is_rocm(): reshape_and_cache_flash = ops.reshape_and_cache_flash -def get_flash_attn_version(requires_alibi: bool = False) -> int | None: +def get_flash_attn_version( + requires_alibi: bool = False, head_size: int | None = None +) -> int | None: # import here to avoid circular dependencies from vllm.platforms import current_platform @@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: assert device_capability is not None # 1. 
default version depending on platform - fa_version = ( - 3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 - ) + if device_capability.major == 9 and is_fa_version_supported(3): + # Hopper (SM90): prefer FA3 + fa_version = 3 + elif device_capability.major == 10 and is_fa_version_supported(4): + # Blackwell (SM100+, restrict to SM100 for now): prefer FA4 + fa_version = 4 + else: + # Fallback to FA2 + fa_version = 2 # 2. override if passed by environment or config from vllm.config import get_current_vllm_config_or_none @@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: fa_version = vllm_config.attention_config.flash_attn_version # 3. fallback for unsupported combinations - if device_capability.major == 10 and fa_version == 3: + if device_capability.major >= 10 and fa_version == 3: logger.warning_once( "Cannot use FA version 3 on Blackwell platform, " - "defaulting to FA version 2." + "defaulting to FA version 4 if supported, otherwise FA2." ) - fa_version = 2 + fa_version = 4 if is_fa_version_supported(4) else 2 if requires_alibi and fa_version == 3: logger.warning_once( @@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ) fa_version = 2 + if requires_alibi and fa_version == 4: + logger.warning_once( + "Cannot use FA version 4 with ALiBi, defaulting to FA version 2." + ) + fa_version = 2 + + # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict + # supported head dimensions. + # See: https://github.com/Dao-AILab/flash-attention/issues/1959 + if ( + fa_version == 4 + and device_capability.major >= 10 + and head_size is not None + and head_size > 128 + ): + logger.warning_once( + "FA4 on Blackwell does not support head_size=%d due to TMEM " + "capacity limits, defaulting to FA version 2.", + head_size, + ) + fa_version = 2 + if not is_fa_version_supported(fa_version): logger.error( "Cannot use FA version %d is not supported due to %s", @@ -139,6 +169,10 @@ def flash_attn_supports_mla(): return is_fa_version_supported( 3 ) and current_platform.is_device_capability_family(90) + + # NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard + # head dimensions (576 for qk, 512 for v) due to TMEM capacity limits. 
+ except (ImportError, AssertionError): pass return False diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 940dc7515..91c49c55c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.attn_type = attn_type - self.vllm_flash_attn_version = get_flash_attn_version() + self.vllm_flash_attn_version = get_flash_attn_version( + requires_alibi=alibi_slopes is not None, + head_size=head_size, + ) + logger.info_once( + "Using FlashAttention version %s", + self.vllm_flash_attn_version, + scope="local", + ) # Cache the batch invariant result for use in forward passes self.batch_invariant_enabled = vllm_is_batch_invariant() diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 1578209e6..be459cd29 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -137,7 +137,7 @@ class CudagraphDispatcher: num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): - num_reqs = num_tokens_padded // uniform_decode_query_len + num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs) assert num_tokens_padded % uniform_decode_query_len == 0 else: uniform_decode = False diff --git a/vllm/vllm_flash_attn/__init__.py b/vllm/vllm_flash_attn/__init__.py new file mode 100644 index 000000000..3507defab --- /dev/null +++ b/vllm/vllm_flash_attn/__init__.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.vllm_flash_attn.flash_attn_interface import ( + FA2_AVAILABLE, + FA3_AVAILABLE, + fa_version_unsupported_reason, + flash_attn_varlen_func, + get_scheduler_metadata, + is_fa_version_supported, +) + +if not (FA2_AVAILABLE or FA3_AVAILABLE): + raise ImportError( + "vllm.vllm_flash_attn requires the CUDA flash attention extensions " + "(_vllm_fa2_C or _vllm_fa3_C). On ROCm, use upstream flash_attn." + ) + +__all__ = [ + "fa_version_unsupported_reason", + "flash_attn_varlen_func", + "get_scheduler_metadata", + "is_fa_version_supported", +] diff --git a/vllm/vllm_flash_attn/flash_attn_interface.py b/vllm/vllm_flash_attn/flash_attn_interface.py new file mode 100644 index 000000000..9d9a9be2f --- /dev/null +++ b/vllm/vllm_flash_attn/flash_attn_interface.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2023, Tri Dao. +# ruff: noqa: E501 + + +import torch + +# isort: off +# We need to import the CUDA kernels after importing torch +# Use relative import to support build-from-source installation in vLLM + +try: + from . import _vllm_fa2_C # type: ignore[attr-defined] # noqa: F401 + + FA2_UNAVAILABLE_REASON = None + FA2_AVAILABLE = True +except ImportError as e: + FA2_UNAVAILABLE_REASON = str(e) + FA2_AVAILABLE = False + +try: + from . 
import _vllm_fa3_C # type: ignore[attr-defined] # noqa: F401 + + FA3_UNAVAILABLE_REASON = None + FA3_AVAILABLE = True +except ImportError as e: + FA3_UNAVAILABLE_REASON = str(e) + FA3_AVAILABLE = False + + +try: + import os + + _cute_interface_path = os.path.join( + os.path.dirname(__file__), "cute", "interface.py" + ) + if not os.path.exists(_cute_interface_path): + raise ImportError("vllm.vllm_flash_attn.cute.interface not found") + + FA4_UNAVAILABLE_REASON = None + FA4_AVAILABLE = True +except (ImportError, ModuleNotFoundError) as e: + FA4_UNAVAILABLE_REASON = str(e) + FA4_AVAILABLE = False + +# isort: on + +DEFAULT_FA_VERSION = 2 + + +def _is_fa2_supported() -> tuple[bool, str | None]: + if not FA2_AVAILABLE: + return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}" + from vllm.platforms import current_platform + + if not current_platform.has_device_capability(80): + return False, "FA2 is only supported on devices with compute capability >= 8" + return True, None + + +def _is_fa3_supported() -> tuple[bool, str | None]: + if not FA3_AVAILABLE: + return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}" + from vllm.platforms import current_platform + + if not current_platform.is_device_capability_family(90): + return False, "FA3 is only supported on devices with compute capability 9.x" + return True, None + + +def _is_fa4_supported() -> tuple[bool, str | None]: + if not FA4_AVAILABLE: + return False, f"FA4 is unavailable due to: {FA4_UNAVAILABLE_REASON}" + from vllm.platforms import current_platform + + if not ( + current_platform.is_device_capability_family(90) + or current_platform.is_device_capability_family(100) + or current_platform.is_device_capability_family(110) + ): + return ( + False, + "FA4 is only supported on devices with compute capability 9.x, 10.x, or 11.x", + ) + return True, None + + +def is_fa_version_supported(fa_version: int) -> bool: + if fa_version == 2: + return _is_fa2_supported()[0] + elif fa_version == 3: + return _is_fa3_supported()[0] + elif fa_version == 4: + return _is_fa4_supported()[0] + else: + raise ValueError(f"Unsupported FA version: {fa_version}") + + +def fa_version_unsupported_reason(fa_version: int) -> str | None: + if fa_version == 2: + return _is_fa2_supported()[1] + elif fa_version == 3: + return _is_fa3_supported()[1] + elif fa_version == 4: + return _is_fa4_supported()[1] + else: + raise ValueError(f"Unsupported FA version: {fa_version}") + + +# +# For vLLM we only care about `flash_attn_varlen_func` and +# `flash_attn_with_kvcache` so we only maintain wrappers for these two. 
+# + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +# NOTE only used in FA3 +def get_scheduler_metadata( + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k_new: torch.Tensor | None = None, + cache_leftpad: torch.Tensor | None = None, + page_size: int | None = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication +): + cache_seqlens = maybe_contiguous(cache_seqlens) + if headdim_v is None: + headdim_v = headdim + scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata( + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + headdim_v, + qkv_dtype, + cache_seqlens, + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_leftpad, + page_size, + max_seqlen_k_new, + causal, + window_size[0], + window_size[1], + has_softcap, + num_splits, + pack_gqa, + sm_margin, + ) + + return scheduler_metadata + + +def flash_attn_varlen_func( + q, + k, + v, + max_seqlen_q, + cu_seqlens_q, + max_seqlen_k, + cu_seqlens_k=None, # only used for non-paged prefill + seqused_k=None, + q_v=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size: list[int] | None = None, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, + return_softmax_lse=False, + out=None, + # FA3 Only + scheduler_metadata=None, + q_descale=None, + k_descale=None, + v_descale=None, + num_splits: int = 0, + # Version selector + fa_version: int = DEFAULT_FA_VERSION, + s_aux=None, + cp_world_size=1, + cp_rank=0, + cp_tot_seqused_k=None, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. 
The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + assert cu_seqlens_k is not None or seqused_k is not None, ( + "cu_seqlens_k or seqused_k must be provided" + ) + assert cu_seqlens_k is None or seqused_k is None, ( + "cu_seqlens_k and seqused_k cannot be provided at the same time" + ) + assert block_table is None or seqused_k is not None, ( + "seqused_k must be provided if block_table is provided" + ) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + # custom op does not support non-tuple input + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + + dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) + + if fa_version == 2: + if ( + scheduler_metadata is not None + and q_descale is not None + and k_descale is not None + and v_descale is not None + ): + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale" + ) + if s_aux is not None: + raise NotImplementedError("FA2 does not support s_aux") + if num_splits > 1: + raise NotImplementedError("FA2 does not support num_splits > 1") + out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( + q, + k, + v, + out, + cu_seqlens_q, + # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp + # still wants it so we pass all zeros + dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, + seqused_k, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + real_window_size[0], + real_window_size[1], + softcap, + return_softmax_lse and dropout_p > 0, + num_splits, + None, + ) + elif fa_version == 3: + assert alibi_slopes is None, "Alibi is not supported in FA3" + out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( + q, + k, + v, + None, + None, # k_new, v_new + q_v, + out, + cu_seqlens_q, + cu_seqlens_k, # cu_seqlens_k + None, # cu_seqlens_k_new + None, + seqused_k, # seqused_q, seqused_k + max_seqlen_q, + max_seqlen_k, + block_table, + None, # kv_batch_idx + None, # leftpad_k + None, 
+ None, + None, # rotary_cos, rotary_sin, seqlens_rotary + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + real_window_size[0], + real_window_size[1], + softcap, + True, # rotary_interleaved + scheduler_metadata, + num_splits, + None, # pack_gqa + 0, # sm_margin + s_aux, # s_aux + cp_world_size, + cp_rank, + cp_tot_seqused_k, + ) + elif fa_version == 4: + assert alibi_slopes is None, "Alibi is not supported in FA4" + # FA4 on SM90 doesn't support paged KV; SM100+ does + from vllm.platforms import current_platform + + if block_table is not None and current_platform.is_device_capability_family(90): + raise NotImplementedError( + "FA4 with paged KV is not supported on SM90 (Hopper). " + "Use FA3 or upgrade to Blackwell (SM100+)." + ) + from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd + + out, softmax_lse = _flash_attn_fwd( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + page_table=block_table, + softmax_scale=softmax_scale, + causal=causal, + softcap=softcap, + window_size_left=real_window_size[0] if real_window_size[0] >= 0 else None, + window_size_right=real_window_size[1] if real_window_size[1] >= 0 else None, + num_splits=num_splits, + return_lse=return_softmax_lse, + out=out, + ) + else: + raise ValueError(f"Unsupported FA version: {fa_version}") + return (out, softmax_lse) if return_softmax_lse else out + + +def sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + *, + return_softmax_lse=False, + out=None, +): + """Compute attention with vertical and slash sparsity patterns. + Most Arguments are the same with the flash_attn_func interface, except for 4 extra args: + block_count and block_offset for slash sparsity patterns, and + column_count and column_index for vertical sparsity patterns. + For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S) + column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). 
The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse = torch.ops._vllm_fa2_C.fwd_sparse( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + out, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + softcap, + return_attn_probs and dropout_p > 0, + None, + ) + return (out, softmax_lse) if return_softmax_lse else out + + +def sparse_attn_varlen_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + *, + return_softmax_lse=False, + out=None, +): + """Compute attention with vertical and slash sparsity patterns. + Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args: + block_count and block_offset for slash sparsity patterns, and + column_count and column_index for vertical sparsity patterns. + For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S) + column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V) + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
+ """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + out, + cu_seqlens_q, + cu_seqlens_k, + None, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + softcap, + return_attn_probs and dropout_p > 0, + None, + ) + return (out, softmax_lse) if return_softmax_lse else out