diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index c56dcb443..4aef72821 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -39,7 +39,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb + GIT_TAG f5bc33cfc02c744d24a2e9d50e6db656de40611c GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index a4423b301..97fc35e70 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -154,6 +154,17 @@ def get_flash_attn_version( return None +def is_fa_version_supported(fa_version: int) -> bool: + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + is_fa_version_supported as _is_fa_version_supported, + ) + + return _is_fa_version_supported(fa_version) + except ImportError: + return False + + def flash_attn_supports_fp8() -> bool: return ( get_flash_attn_version() == 3 diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5e81cae42..b9ccd4fce 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,6 +10,7 @@ import numpy as np import torch from vllm.model_executor.layers.attention import Attention +from vllm.platforms import current_platform from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( AttentionBackend, @@ -20,6 +21,7 @@ from vllm.v1.attention.backend import ( from vllm.v1.attention.backends.fa_utils import ( flash_attn_supports_fp8, get_flash_attn_version, + is_fa_version_supported, is_flash_attn_varlen_func_available, ) from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens @@ -45,7 +47,6 @@ from vllm.config import ( from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv, round_up from vllm.v1.attention.backend import ( @@ -170,7 +171,13 @@ class FlashAttentionBackend(AttentionBackend): @classmethod def supports_head_size(cls, head_size: int) -> bool: - return head_size % 8 == 0 and head_size <= 256 + if head_size % 8 != 0: + return False + if head_size <= 256: + return True + if is_fa_version_supported(4): + return head_size <= 512 + return False @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: @@ -618,6 +625,14 @@ class FlashAttentionImpl(AttentionImpl): requires_alibi=alibi_slopes is not None, head_size=head_size, ) + # head_size > 256 requires FA4 on SM90+; force upgrade from FA3 + if ( + head_size > 256 + and self.vllm_flash_attn_version == 3 + and current_platform.is_cuda() + and current_platform.is_device_capability_family(90) + ): + self.vllm_flash_attn_version = 4 logger.info_once( "Using FlashAttention version %s", self.vllm_flash_attn_version, diff --git a/vllm/vllm_flash_attn/flash_attn_interface.py b/vllm/vllm_flash_attn/flash_attn_interface.py index 9d9a9be2f..eb0dbd423 100644 --- a/vllm/vllm_flash_attn/flash_attn_interface.py +++ b/vllm/vllm_flash_attn/flash_attn_interface.py @@ -366,14 +366,7 @@ def flash_attn_varlen_func( ) elif fa_version == 4: assert alibi_slopes is None, "Alibi is not supported in FA4" - # FA4 on SM90 doesn't support paged KV; SM100+ does - from vllm.platforms import current_platform - if block_table is not None and current_platform.is_device_capability_family(90): - raise NotImplementedError( - "FA4 with paged KV is not supported on SM90 (Hopper). " - "Use FA3 or upgrade to Blackwell (SM100+)." - ) from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd out, softmax_lse = _flash_attn_fwd(