[MM Encoder] Default to use TORCH_SDPA backend for ViT on Volta/Turing GPU (#36472)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.attention import MMEncoderAttention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
@@ -83,6 +84,20 @@ def test_mha_attn_platform(default_vllm_config, device: str):
|
||||
attn = MMEncoderAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
|
||||
|
||||
# Test Turing (pre-Ampere, sm_75): FlashAttention requires sm>=80,
|
||||
# and Triton no longer supports MMA on Turing, so we expect that
|
||||
# TORCH_SDPA is used for MMEncoderAttention.
|
||||
with (
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
patch.object(
|
||||
CudaPlatform,
|
||||
"get_device_capability",
|
||||
return_value=DeviceCapability(major=7, minor=5),
|
||||
),
|
||||
):
|
||||
attn = MMEncoderAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
|
||||
def ref_attention(
|
||||
query: torch.Tensor,
|
||||
|
||||
@@ -413,12 +413,20 @@ class CudaPlatformBase(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
]
|
||||
if cls.has_device_capability(80):
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(
|
||||
@@ -438,7 +446,7 @@ class CudaPlatformBase(Platform):
|
||||
cc = cls.get_device_capability()
|
||||
for vit_attn_backend in cls.get_supported_vit_attn_backends():
|
||||
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
continue
|
||||
return vit_attn_backend
|
||||
try:
|
||||
backend_class = vit_attn_backend.get_class()
|
||||
is_backend_supported = backend_class.supports_head_size(
|
||||
|
||||
Reference in New Issue
Block a user