[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-11-11 06:40:44 -06:00
committed by GitHub
parent 2e78150d24
commit b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

View File

@@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
)
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:
@@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens