[Attention] FA4 integration (#32974)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2026-03-01 18:44:57 -05:00
committed by GitHub
parent 57a96e26c9
commit 8b5014d3dd
15 changed files with 818 additions and 55 deletions

View File

@@ -9,6 +9,7 @@ steps:
- tests/v1
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core
- pytest -v -s v1/executor

2
.gitignore vendored
View File

@@ -3,6 +3,8 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*

View File

@@ -17,7 +17,8 @@ endif()
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
# If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
@@ -38,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -46,38 +47,62 @@ else()
endif()
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
# Make sure vllm-flash-attn install rules are nested under vllm/
# This is here to support installing all components under the same prefix with cmake --install.
# setup.py installs every component separately but uses the same prefix for all.
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
# and these statements don't hurt when installing neither component.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
# Install rules for FA components need the install prefix nested under vllm/
# These run at install time, before the FA library's own install rules
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT})
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT})
endforeach()
# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Restore the install prefix after FA's install rules
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT})
endforeach()
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)
# Install shared Python files for both FA2 and FA3 components
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
COMPONENT ${_FA_COMPONENT})
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)
# Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
# which are source-controlled in vllm)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT ${_FA_COMPONENT}
FILES_MATCHING PATTERN "*.py"
PATTERN "__init__.py" EXCLUDE
PATTERN "flash_attn_interface.py" EXCLUDE
)
endforeach()
#
# FA4 CuteDSL component
# This is a Python-only component that copies the flash_attn/cute directory
# and transforms imports to match our package structure.
#
add_custom_target(_vllm_fa4_cutedsl_C)
# Copy flash_attn/cute directory (needed for FA4) and transform imports
# The cute directory uses flash_attn.cute imports internally, which we replace
# with vllm.vllm_flash_attn.cute to match our package structure.
install(CODE "
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
foreach(SRC_FILE \${CUTE_PY_FILES})
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
file(MAKE_DIRECTORY \${DST_DIR})
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)

View File

@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first).
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first).
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
>
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise.
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise.
## MLA (Multi-head Latent Attention) Backends

View File

@@ -11,3 +11,7 @@ torchaudio==2.10.0
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.4
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl>=4.4.0.dev1
quack-kernels>=0.2.7

View File

@@ -976,6 +976,11 @@ if _is_cuda():
):
# FA3 requires CUDA 12.3 or later
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
# FA4 CuteDSL - Python-only component for FA4's cute DSL support
# Optional since this doesn't produce a .so file, just copies Python files
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa4_cutedsl_C", optional=True)
)
if envs.VLLM_USE_PRECOMPILED or (
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
):

View File

@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
# ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
# ---------------------------------------------------------------------------
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
def _parse_fa4_supported_caps() -> str | None:
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
"""
fa_interface_file = (
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
)
if not fa_interface_file.exists():
return None
try:
tree = ast.parse(fa_interface_file.read_text())
except Exception:
return None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
continue
for n in ast.walk(node):
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.NotIn)
and isinstance(n.comparators[0], ast.List)
):
continue
caps: list[int] = [
e.value
for e in n.comparators[0].elts
if isinstance(e, ast.Constant) and isinstance(e.value, int)
]
if caps:
caps.sort()
return f"{caps[0]}.x-{caps[-1]}.x"
return None
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if not FA_UTILS_FILE.exists():
@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8 = False
fa3_supports_sinks = False
fa3_compute_cap: str | None = None
fa4_compute_cap: str | None = None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef):
@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_sinks = True
break
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
# Check get_flash_attn_version for FA3/FA4 compute capability
if node.name == "get_flash_attn_version":
for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9`
# Handle IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp):
test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp):
for val in test.values:
if (
@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_compute_cap = f"{val.comparators[0].value}.x"
break
# Handle If statements for FA3/FA4 detection
# e.g. `if device_capability.major == 9` -> FA3
# `elif device_capability.major >= 10` -> FA4
if isinstance(n, ast.If):
test = n.test
comparisons = (
[v for v in test.values if isinstance(v, ast.Compare)]
if isinstance(test, ast.BoolOp)
else [test]
if isinstance(test, ast.Compare)
else []
)
for comp in comparisons:
if not (
isinstance(comp.left, ast.Attribute)
and comp.left.attr == "major"
and comp.comparators
and isinstance(comp.comparators[0], ast.Constant)
and isinstance(comp.comparators[0].value, int)
):
continue
op = comp.ops[0]
val = comp.comparators[0].value
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
fa3_compute_cap = f"{val}.x"
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
fa4_compute_cap = f"{val}.0"
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
if fa4_compute_cap is None:
fa4_compute_cap = _parse_fa4_supported_caps()
return {
"fa2": {
"supports_fp8": False,
@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks,
},
"fa4": {
"compute_capability": fa4_compute_cap,
"supports_fp8": False,
"supports_sink": False,
},
}
@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------
@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
all_backends: list[dict[str, Any]],
fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
"""Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
expanded = []
for backend in all_backends:
if backend["name"] != "FLASH_ATTN":
@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
expanded.append(fa2)
expanded.append(fa3)
# Create FA4 entry if FA4 features are available
if "fa4" in fa_features:
fa4 = backend.copy()
fa4["version"] = "FA4*"
fa4["_sort_key"] = "FLASH_ATTN"
fa4["_sort_order"] = 2
if fa_features["fa4"].get("compute_capability"):
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
expanded.append(fa4)
return expanded
@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
if fa_features:
footnotes.append(
"> **\\*** Specify the FlashAttention version via "
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, "
"`--attention-config.flash_attn_version=2`, `3`, or `4`. "
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
"FA2 otherwise."
)
if footnotes:

View File

@@ -16,8 +16,8 @@ class AttentionConfig:
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3).
flash_attn_version: Literal[2, 3, 4] | None = None
"""Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False

View File

@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# RoCM and the latter has an additional parameter to control
# FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
head_size=self.qk_head_dim
)
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version

View File

@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp):
}
self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
get_flash_attn_version(head_size=head_size)
if self.is_flash_attn_backend
else None
)
if self.attn_backend == AttentionBackendEnum.FLASHINFER:

View File

@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
if device_capability.major == 9 and is_fa_version_supported(3):
# Hopper (SM90): prefer FA3
fa_version = 3
elif device_capability.major == 10 and is_fa_version_supported(4):
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
fa_version = 4
else:
# Fallback to FA2
fa_version = 2
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none
@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
if device_capability.major >= 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2."
"defaulting to FA version 4 if supported, otherwise FA2."
)
fa_version = 2
fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3:
logger.warning_once(
@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
)
fa_version = 2
if requires_alibi and fa_version == 4:
logger.warning_once(
"Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
if (
fa_version == 4
and device_capability.major >= 10
and head_size is not None
and head_size > 128
):
logger.warning_once(
"FA4 on Blackwell does not support head_size=%d due to TMEM "
"capacity limits, defaulting to FA version 2.",
head_size,
)
fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError):
pass
return False

View File

@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()

View File

@@ -137,7 +137,7 @@ class CudagraphDispatcher:
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
assert num_tokens_padded % uniform_decode_query_len == 0
else:
uniform_decode = False

View File

@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.vllm_flash_attn.flash_attn_interface import (
FA2_AVAILABLE,
FA3_AVAILABLE,
fa_version_unsupported_reason,
flash_attn_varlen_func,
get_scheduler_metadata,
is_fa_version_supported,
)
if not (FA2_AVAILABLE or FA3_AVAILABLE):
raise ImportError(
"vllm.vllm_flash_attn requires the CUDA flash attention extensions "
"(_vllm_fa2_C or _vllm_fa3_C). On ROCm, use upstream flash_attn."
)
__all__ = [
"fa_version_unsupported_reason",
"flash_attn_varlen_func",
"get_scheduler_metadata",
"is_fa_version_supported",
]

View File

@@ -0,0 +1,567 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023, Tri Dao.
# ruff: noqa: E501
import torch
# isort: off
# We need to import the CUDA kernels after importing torch
# Use relative import to support build-from-source installation in vLLM
try:
from . import _vllm_fa2_C # type: ignore[attr-defined] # noqa: F401
FA2_UNAVAILABLE_REASON = None
FA2_AVAILABLE = True
except ImportError as e:
FA2_UNAVAILABLE_REASON = str(e)
FA2_AVAILABLE = False
try:
from . import _vllm_fa3_C # type: ignore[attr-defined] # noqa: F401
FA3_UNAVAILABLE_REASON = None
FA3_AVAILABLE = True
except ImportError as e:
FA3_UNAVAILABLE_REASON = str(e)
FA3_AVAILABLE = False
try:
import os
_cute_interface_path = os.path.join(
os.path.dirname(__file__), "cute", "interface.py"
)
if not os.path.exists(_cute_interface_path):
raise ImportError("vllm.vllm_flash_attn.cute.interface not found")
FA4_UNAVAILABLE_REASON = None
FA4_AVAILABLE = True
except (ImportError, ModuleNotFoundError) as e:
FA4_UNAVAILABLE_REASON = str(e)
FA4_AVAILABLE = False
# isort: on
DEFAULT_FA_VERSION = 2
def _is_fa2_supported() -> tuple[bool, str | None]:
if not FA2_AVAILABLE:
return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
from vllm.platforms import current_platform
if not current_platform.has_device_capability(80):
return False, "FA2 is only supported on devices with compute capability >= 8"
return True, None
def _is_fa3_supported() -> tuple[bool, str | None]:
if not FA3_AVAILABLE:
return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
from vllm.platforms import current_platform
if not current_platform.is_device_capability_family(90):
return False, "FA3 is only supported on devices with compute capability 9.x"
return True, None
def _is_fa4_supported() -> tuple[bool, str | None]:
if not FA4_AVAILABLE:
return False, f"FA4 is unavailable due to: {FA4_UNAVAILABLE_REASON}"
from vllm.platforms import current_platform
if not (
current_platform.is_device_capability_family(90)
or current_platform.is_device_capability_family(100)
or current_platform.is_device_capability_family(110)
):
return (
False,
"FA4 is only supported on devices with compute capability 9.x, 10.x, or 11.x",
)
return True, None
def is_fa_version_supported(fa_version: int) -> bool:
if fa_version == 2:
return _is_fa2_supported()[0]
elif fa_version == 3:
return _is_fa3_supported()[0]
elif fa_version == 4:
return _is_fa4_supported()[0]
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
def fa_version_unsupported_reason(fa_version: int) -> str | None:
if fa_version == 2:
return _is_fa2_supported()[1]
elif fa_version == 3:
return _is_fa3_supported()[1]
elif fa_version == 4:
return _is_fa4_supported()[1]
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
#
# For vLLM we only care about `flash_attn_varlen_func` and
# `flash_attn_with_kvcache` so we only maintain wrappers for these two.
#
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# NOTE only used in FA3
def get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: torch.Tensor | None = None,
cu_seqlens_k_new: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = None,
page_size: int | None = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
):
cache_seqlens = maybe_contiguous(cache_seqlens)
if headdim_v is None:
headdim_v = headdim
scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
headdim_v,
qkv_dtype,
cache_seqlens,
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_leftpad,
page_size,
max_seqlen_k_new,
causal,
window_size[0],
window_size[1],
has_softcap,
num_splits,
pack_gqa,
sm_margin,
)
return scheduler_metadata
def flash_attn_varlen_func(
q,
k,
v,
max_seqlen_q,
cu_seqlens_q,
max_seqlen_k,
cu_seqlens_k=None, # only used for non-paged prefill
seqused_k=None,
q_v=None,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size: list[int] | None = None,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
block_table=None,
return_softmax_lse=False,
out=None,
# FA3 Only
scheduler_metadata=None,
q_descale=None,
k_descale=None,
v_descale=None,
num_splits: int = 0,
# Version selector
fa_version: int = DEFAULT_FA_VERSION,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert cu_seqlens_k is not None or seqused_k is not None, (
"cu_seqlens_k or seqused_k must be provided"
)
assert cu_seqlens_k is None or seqused_k is None, (
"cu_seqlens_k and seqused_k cannot be provided at the same time"
)
assert block_table is None or seqused_k is not None, (
"seqused_k must be provided if block_table is provided"
)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
# custom op does not support non-tuple input
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
if fa_version == 2:
if (
scheduler_metadata is not None
and q_descale is not None
and k_descale is not None
and v_descale is not None
):
raise NotImplementedError(
"FA2 does not support scheduler_metadata, q_descale, "
"k_descale, v_descale"
)
if s_aux is not None:
raise NotImplementedError("FA2 does not support s_aux")
if num_splits > 1:
raise NotImplementedError("FA2 does not support num_splits > 1")
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
q,
k,
v,
out,
cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k,
None,
block_table,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
real_window_size[0],
real_window_size[1],
softcap,
return_softmax_lse and dropout_p > 0,
num_splits,
None,
)
elif fa_version == 3:
assert alibi_slopes is None, "Alibi is not supported in FA3"
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q,
k,
v,
None,
None, # k_new, v_new
q_v,
out,
cu_seqlens_q,
cu_seqlens_k, # cu_seqlens_k
None, # cu_seqlens_k_new
None,
seqused_k, # seqused_q, seqused_k
max_seqlen_q,
max_seqlen_k,
block_table,
None, # kv_batch_idx
None, # leftpad_k
None,
None,
None, # rotary_cos, rotary_sin, seqlens_rotary
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
real_window_size[0],
real_window_size[1],
softcap,
True, # rotary_interleaved
scheduler_metadata,
num_splits,
None, # pack_gqa
0, # sm_margin
s_aux, # s_aux
cp_world_size,
cp_rank,
cp_tot_seqused_k,
)
elif fa_version == 4:
assert alibi_slopes is None, "Alibi is not supported in FA4"
# FA4 on SM90 doesn't support paged KV; SM100+ does
from vllm.platforms import current_platform
if block_table is not None and current_platform.is_device_capability_family(90):
raise NotImplementedError(
"FA4 with paged KV is not supported on SM90 (Hopper). "
"Use FA3 or upgrade to Blackwell (SM100+)."
)
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
out, softmax_lse = _flash_attn_fwd(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
seqused_k=seqused_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
page_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size_left=real_window_size[0] if real_window_size[0] >= 0 else None,
window_size_right=real_window_size[1] if real_window_size[1] >= 0 else None,
num_splits=num_splits,
return_lse=return_softmax_lse,
out=out,
)
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.fwd_sparse(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_varlen_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out