diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index bc99ed576..3bcde3b0a 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.platforms import current_platform from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cuda import CudaPlatform +from vllm.platforms.interface import DeviceCapability from vllm.platforms.rocm import RocmPlatform from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -83,6 +84,20 @@ def test_mha_attn_platform(default_vllm_config, device: str): attn = MMEncoderAttention(16, 72, scale=1) assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN + # Test Turing (pre-Ampere, sm_75): FlashAttention requires sm>=80, + # and Triton no longer supports MMA on Turing, so we expect that + # TORCH_SDPA is used for MMEncoderAttention. + with ( + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + patch.object( + CudaPlatform, + "get_device_capability", + return_value=DeviceCapability(major=7, minor=5), + ), + ): + attn = MMEncoderAttention(16, 64, scale=1) + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA + def ref_attention( query: torch.Tensor, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index d3d75d883..651cf86b1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -413,12 +413,20 @@ class CudaPlatformBase(Platform): @classmethod def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: - return [ - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TRITON_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.FLASHINFER, - ] + if cls.has_device_capability(80): + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.FLASHINFER, + ] + else: + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.FLASHINFER, + ] @classmethod def get_vit_attn_backend( @@ -438,7 +446,7 @@ class CudaPlatformBase(Platform): cc = cls.get_device_capability() for vit_attn_backend in cls.get_supported_vit_attn_backends(): if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA: - continue + return vit_attn_backend try: backend_class = vit_attn_backend.get_class() is_backend_supported = backend_class.supports_head_size(