[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Matthew Bonanni
Date: 2025-11-11 06:40:44 -06:00
Committed by: GitHub
Parent: 2e78150d24
Commit: b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

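The diff below renames the vision-attention backend enum from `_Backend` to `AttentionBackendEnum` (imported from `vllm.attention.backends.registry`) across the SigLIP modules. As a rough orientation, here is a minimal sketch of the call-site pattern after the rename; the helper function name is illustrative only and not part of this commit:

```python
from vllm.attention.backends.registry import AttentionBackendEnum

def is_flash_attn_variant(backend: AttentionBackendEnum) -> bool:
    # Hypothetical helper mirroring the is_flash_attn_backend check in
    # SiglipAttention: both upstream flash attention and ROCm's AITER
    # flash attention count as flash-attention backends.
    return backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
```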

@@ -31,7 +31,7 @@ from transformers.modeling_outputs import (
)
from transformers.utils import torch_int
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
@@ -580,8 +580,8 @@ class SiglipAttention(nn.Module):
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
-attn_backend: _Backend = _Backend.TORCH_SDPA,
-attn_backend_override: _Backend | None = None,
+attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
+attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
@@ -621,8 +621,8 @@ class SiglipAttention(nn.Module):
)
)
self.is_flash_attn_backend = self.attn_backend in {
-_Backend.FLASH_ATTN,
-_Backend.ROCM_AITER_FA,
+AttentionBackendEnum.FLASH_ATTN,
+AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -680,10 +680,10 @@ class SiglipAttention(nn.Module):
cu_seqlens,
max_seqlen,
batch_size,
-self.attn_backend == _Backend.ROCM_AITER_FA,
+self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
-elif self.attn_backend == _Backend.TORCH_SDPA:
+elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
@@ -702,7 +702,7 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
-elif self.attn_backend == _Backend.XFORMERS:
+elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
@@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
-attn_backend: _Backend = _Backend.TORCH_SDPA,
-attn_backend_override: _Backend | None = None,
+attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
+attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
):
super().__init__()
@@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
-attn_backend_override: _Backend | None = None,
+attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module):
)
self.use_upstream_fa = False
if self.attn_backend not in {
-_Backend.FLASH_ATTN,
-_Backend.ROCM_AITER_FA,
+AttentionBackendEnum.FLASH_ATTN,
+AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
-self.attn_backend = _Backend.FLASH_ATTN
+self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
-_Backend.FLASH_ATTN,
-_Backend.TORCH_SDPA,
-_Backend.XFORMERS,
-_Backend.ROCM_AITER_FA,
+AttentionBackendEnum.FLASH_ATTN,
+AttentionBackendEnum.TORCH_SDPA,
+AttentionBackendEnum.XFORMERS,
+AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module):
max_seqlen = None
seqlens = None
-if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+if self.attn_backend in {
+AttentionBackendEnum.FLASH_ATTN,
+AttentionBackendEnum.ROCM_AITER_FA,
+}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-elif self.attn_backend == _Backend.XFORMERS:
+elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
@@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
-attn_backend_override: _Backend | None = None,
+attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module):
config,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
-attn_backend_override: _Backend | None = None,
+attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
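For reference, the backend-selection behaviour visible in the SiglipEncoder hunk above can be summarised as follows. This is a hedged sketch: the function name, the SUPPORTED constant, and the upstream_fa_available parameter are illustrative stand-ins (the real code calls check_upstream_fa_availability with the default dtype), not code from this commit.

```python
from vllm.attention.backends.registry import AttentionBackendEnum

# Backends accepted by the PaddleOCR-VL SigLIP encoder per the diff above.
SUPPORTED = {
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.TORCH_SDPA,
    AttentionBackendEnum.XFORMERS,
    AttentionBackendEnum.ROCM_AITER_FA,
}

def resolve_encoder_backend(
    selected: AttentionBackendEnum,
    upstream_fa_available: bool,
) -> tuple[AttentionBackendEnum, bool]:
    """Return (backend, use_upstream_fa), mirroring SiglipEncoder.__init__."""
    use_upstream_fa = False
    flash_variants = {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
    # Prefer upstream flash attention when the selected backend is not already
    # a flash-attention variant and upstream FA is importable.
    if selected not in flash_variants and upstream_fa_available:
        selected = AttentionBackendEnum.FLASH_ATTN
        use_upstream_fa = True
    # Anything outside the supported set is rejected, as in the diff.
    if selected not in SUPPORTED:
        raise RuntimeError(f"PaddleOCR-VL does not support {selected} backend now.")
    return selected, use_upstream_fa
```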