[Attention] FA4 integration (#32974)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: Lucas Wilkinson
Date: 2026-03-01 18:44:57 -05:00 (committed by GitHub)
Parent: 57a96e26c9
Commit: 8b5014d3dd
15 changed files with 818 additions and 55 deletions


@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
     reshape_and_cache_flash = ops.reshape_and_cache_flash


-def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
+def get_flash_attn_version(
+    requires_alibi: bool = False, head_size: int | None = None
+) -> int | None:
     # import here to avoid circular dependencies
     from vllm.platforms import current_platform
@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
     assert device_capability is not None

     # 1. default version depending on platform
-    fa_version = (
-        3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
-    )
+    if device_capability.major == 9 and is_fa_version_supported(3):
+        # Hopper (SM90): prefer FA3
+        fa_version = 3
+    elif device_capability.major == 10 and is_fa_version_supported(4):
+        # Blackwell (SM100+, restrict to SM100 for now): prefer FA4
+        fa_version = 4
+    else:
+        # Fallback to FA2
+        fa_version = 2

     # 2. override if passed by environment or config
     from vllm.config import get_current_vllm_config_or_none
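
For readers skimming the diff, the new default-selection order boils down to the sketch below. This is an illustrative standalone helper, not code from this commit; is_fa_version_supported is assumed to be the same predicate used in the diff, passed in here as a callable.

    # Illustrative sketch only (not part of this commit): the default FA version
    # preference, keyed on the device's major compute capability and a
    # version-support predicate like the one used in the diff.
    from collections.abc import Callable

    def default_fa_version(
        capability_major: int, is_fa_version_supported: Callable[[int], bool]
    ) -> int:
        if capability_major == 9 and is_fa_version_supported(3):
            return 3  # Hopper (SM90): prefer FA3
        if capability_major == 10 and is_fa_version_supported(4):
            return 4  # Blackwell (SM100): prefer FA4
        return 2  # everything else falls back to FA2

    # Example: SM90 with FA3 available -> 3; SM100 with FA4 available -> 4;
    # anything else -> 2.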
@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
             fa_version = vllm_config.attention_config.flash_attn_version

     # 3. fallback for unsupported combinations
-    if device_capability.major == 10 and fa_version == 3:
+    if device_capability.major >= 10 and fa_version == 3:
         logger.warning_once(
             "Cannot use FA version 3 on Blackwell platform, "
-            "defaulting to FA version 2."
+            "defaulting to FA version 4 if supported, otherwise FA2."
         )
-        fa_version = 2
+        fa_version = 4 if is_fa_version_supported(4) else 2

     if requires_alibi and fa_version == 3:
         logger.warning_once(
@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
         )
         fa_version = 2

+    if requires_alibi and fa_version == 4:
+        logger.warning_once(
+            "Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
+        )
+        fa_version = 2
+
+    # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
+    # supported head dimensions.
+    # See: https://github.com/Dao-AILab/flash-attention/issues/1959
+    if (
+        fa_version == 4
+        and device_capability.major >= 10
+        and head_size is not None
+        and head_size > 128
+    ):
+        logger.warning_once(
+            "FA4 on Blackwell does not support head_size=%d due to TMEM "
+            "capacity limits, defaulting to FA version 2.",
+            head_size,
+        )
+        fa_version = 2
+
     if not is_fa_version_supported(fa_version):
         logger.error(
             "Cannot use FA version %d is not supported due to %s",
@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
         return is_fa_version_supported(
             3
         ) and current_platform.is_device_capability_family(90)
+        # NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
+        # head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
     except (ImportError, AssertionError):
         pass
     return False


@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.attn_type = attn_type
-        self.vllm_flash_attn_version = get_flash_attn_version()
+        self.vllm_flash_attn_version = get_flash_attn_version(
+            requires_alibi=alibi_slopes is not None,
+            head_size=head_size,
+        )
+        logger.info_once(
+            "Using FlashAttention version %s",
+            self.vllm_flash_attn_version,
+            scope="local",
+        )

         # Cache the batch invariant result for use in forward passes
         self.batch_invariant_enabled = vllm_is_batch_invariant()
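
As a usage note, the backend now resolves the FA version with both constraints in hand. A minimal hedged sketch follows; the import path and argument values are assumptions for illustration, not taken from this diff.

    # Hedged usage sketch: resolve the FA version the way FlashAttentionImpl now
    # does. The module path is assumed; adjust to wherever get_flash_attn_version
    # lives in your tree.
    from vllm.attention.utils.fa_utils import get_flash_attn_version

    fa_version = get_flash_attn_version(
        requires_alibi=False,  # True when alibi_slopes is not None
        head_size=128,         # head sizes > 128 fall back to FA2 on Blackwell
    )
    print(f"Using FlashAttention version {fa_version}")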