[Attention] FA4 integration (#32974)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
def get_flash_attn_version(
|
||||
requires_alibi: bool = False, head_size: int | None = None
|
||||
) -> int | None:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
assert device_capability is not None
|
||||
|
||||
# 1. default version depending on platform
|
||||
fa_version = (
|
||||
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
|
||||
)
|
||||
if device_capability.major == 9 and is_fa_version_supported(3):
|
||||
# Hopper (SM90): prefer FA3
|
||||
fa_version = 3
|
||||
elif device_capability.major == 10 and is_fa_version_supported(4):
|
||||
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
|
||||
fa_version = 4
|
||||
else:
|
||||
# Fallback to FA2
|
||||
fa_version = 2
|
||||
|
||||
# 2. override if passed by environment or config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
fa_version = vllm_config.attention_config.flash_attn_version
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
if device_capability.major == 10 and fa_version == 3:
|
||||
if device_capability.major >= 10 and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 on Blackwell platform, "
|
||||
"defaulting to FA version 2."
|
||||
"defaulting to FA version 4 if supported, otherwise FA2."
|
||||
)
|
||||
fa_version = 2
|
||||
fa_version = 4 if is_fa_version_supported(4) else 2
|
||||
|
||||
if requires_alibi and fa_version == 3:
|
||||
logger.warning_once(
|
||||
@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if requires_alibi and fa_version == 4:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
|
||||
# supported head dimensions.
|
||||
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
|
||||
if (
|
||||
fa_version == 4
|
||||
and device_capability.major >= 10
|
||||
and head_size is not None
|
||||
and head_size > 128
|
||||
):
|
||||
logger.warning_once(
|
||||
"FA4 on Blackwell does not support head_size=%d due to TMEM "
|
||||
"capacity limits, defaulting to FA version 2.",
|
||||
head_size,
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error(
|
||||
"Cannot use FA version %d is not supported due to %s",
|
||||
@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
|
||||
return is_fa_version_supported(
|
||||
3
|
||||
) and current_platform.is_device_capability_family(90)
|
||||
|
||||
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
|
||||
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
requires_alibi=alibi_slopes is not None,
|
||||
head_size=head_size,
|
||||
)
|
||||
logger.info_once(
|
||||
"Using FlashAttention version %s",
|
||||
self.vllm_flash_attn_version,
|
||||
scope="local",
|
||||
)
|
||||
# Cache the batch invariant result for use in forward passes
|
||||
self.batch_invariant_enabled = vllm_is_batch_invariant()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user