[MM Encoder] Default to use TORCH_SDPA backend for ViT on Volta/Turing GPU (#36472)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-03-09 18:43:44 +08:00
committed by GitHub
parent aaf5fa9abf
commit b0906d8b02
2 changed files with 30 additions and 7 deletions

View File

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.interface import DeviceCapability
from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -83,6 +84,20 @@ def test_mha_attn_platform(default_vllm_config, device: str):
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
# Test Turing (pre-Ampere, sm_75): FlashAttention requires sm>=80,
# and Triton no longer supports MMA on Turing, so we expect that
# TORCH_SDPA is used for MMEncoderAttention.
with (
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch.object(
CudaPlatform,
"get_device_capability",
return_value=DeviceCapability(major=7, minor=5),
),
):
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
def ref_attention(
query: torch.Tensor,

View File

@@ -413,12 +413,20 @@ class CudaPlatformBase(Platform):
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the candidate ViT attention backends in preference order.

    The order depends on the device's compute capability because the
    caller walks this list from front to back when selecting a backend.

    Returns:
        A list of ``AttentionBackendEnum`` members, most preferred first.
    """
    # NOTE(review): has_device_capability(80) is presumably true on
    # Ampere-or-newer GPUs (sm >= 80) -- confirm against the Platform API.
    if cls.has_device_capability(80):
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.FLASHINFER,
        ]

    # Older GPUs (e.g. Volta/Turing): FlashAttention requires sm >= 80 and
    # Triton no longer supports MMA on Turing, so TORCH_SDPA is ranked
    # ahead of TRITON_ATTN and becomes the effective default there.
    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLASHINFER,
    ]
@classmethod
def get_vit_attn_backend(
@@ -438,7 +446,7 @@ class CudaPlatformBase(Platform):
cc = cls.get_device_capability()
for vit_attn_backend in cls.get_supported_vit_attn_backends():
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
continue
return vit_attn_backend
try:
backend_class = vit_attn_backend.get_class()
is_backend_supported = backend_class.supports_head_size(