[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Matthew Bonanni
Committed by: GitHub
Date: 2025-11-11 06:40:44 -06:00
Parent: 2e78150d24
Commit: b30dfa03c5
61 changed files with 1338 additions and 1002 deletions


@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
 )
 from transformers.video_utils import VideoMetadata
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
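
The import rename above carries through the rest of the file. Below is a minimal sketch of how code outside this diff might check a backend against the members the diff references; only FLASH_ATTN, TORCH_SDPA, XFORMERS, and ROCM_AITER_FA are assumed to exist, and `is_supported_vision_backend` is a hypothetical helper, not part of the change:

```python
# Hedged sketch: uses only the names visible in this diff
# (AttentionBackendEnum and its FLASH_ATTN / TORCH_SDPA / XFORMERS /
# ROCM_AITER_FA members); no other enum members are assumed.
from vllm.attention.backends.registry import AttentionBackendEnum

VISION_BACKENDS = {
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.TORCH_SDPA,
    AttentionBackendEnum.XFORMERS,
    AttentionBackendEnum.ROCM_AITER_FA,
}


def is_supported_vision_backend(backend: AttentionBackendEnum) -> bool:
    # Mirrors the membership check added to Qwen3_VisionTransformer below.
    return backend in VISION_BACKENDS
```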
@@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
@@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
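
Both signature changes above only swap the annotation from `_Backend` to `AttentionBackendEnum`; the defaults are unchanged. A hedged sketch of how a caller might resolve the new `attn_backend_override` argument down to the `TORCH_SDPA` default follows (`resolve_vision_attn_backend` is hypothetical and not part of this PR):

```python
from vllm.attention.backends.registry import AttentionBackendEnum


def resolve_vision_attn_backend(
    attn_backend_override: AttentionBackendEnum | None,
) -> AttentionBackendEnum:
    # Hypothetical helper: an explicit override (as accepted by
    # Qwen3_VisionTransformer) wins; otherwise fall back to the
    # TORCH_SDPA default used by Qwen3_VisionBlock.
    if attn_backend_override is not None:
        return attn_backend_override
    return AttentionBackendEnum.TORCH_SDPA
```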
@@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module):
         )
         use_upstream_fa = False
         if (
-            self.attn_backend != _Backend.FLASH_ATTN
-            and self.attn_backend != _Backend.ROCM_AITER_FA
+            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
+            and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
             and check_upstream_fa_availability(torch.get_default_dtype())
         ):
-            self.attn_backend = _Backend.FLASH_ATTN
+            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
             use_upstream_fa = True
         if self.attn_backend not in {
-            _Backend.FLASH_ATTN,
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
-            _Backend.ROCM_AITER_FA,
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
+            AttentionBackendEnum.XFORMERS,
+            AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"Qwen3-VL does not support {self.attn_backend} backend now."
@@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module):
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
         seqlens = torch.zeros(1, device=cu_seqlens.device)
         if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
+            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
+            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
         ):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == _Backend.XFORMERS:
+        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
             seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
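
For reference, a small runnable example of what this helper computes from `cu_seqlens` (the cumulative-sequence-length tensor); the values here are made up for illustration:

```python
import torch

# cu_seqlens holds cumulative sequence lengths, so adjacent differences
# recover the per-sequence lengths: three sequences of 16, 32, 32 tokens.
cu_seqlens = torch.tensor([0, 16, 48, 80])

# FLASH_ATTN / ROCM_AITER_FA path: only the maximum length is needed.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # tensor(32)

# XFORMERS path: the full per-sequence length vector is needed.
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([16, 32, 32])
```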