[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-11-11 06:40:44 -06:00
committed by GitHub
parent 2e78150d24
commit b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

View File

@@ -9,7 +9,7 @@ import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
@@ -256,7 +256,7 @@ class DotsVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -303,17 +303,17 @@ class DotsVisionAttention(nn.Module):
)
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Unsupported vision attention backend: {self.attn_backend}"
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
@@ -361,7 +361,7 @@ class DotsVisionAttention(nn.Module):
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
s = int(cu_seqlens[i - 1])
@@ -373,7 +373,7 @@ class DotsVisionAttention(nn.Module):
out_i = out_i.permute(0, 2, 1, 3)
outputs.append(out_i)
context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -514,7 +514,7 @@ class DotsVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
@@ -567,7 +567,7 @@ class DotsVisionTransformer(nn.Module):
require_post_norm: bool | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.config = config
@@ -582,10 +582,11 @@ class DotsVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers
num_layers = (
@@ -666,11 +667,11 @@ class DotsVisionTransformer(nn.Module):
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens

View File

@@ -36,7 +36,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
@@ -164,7 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -211,17 +211,17 @@ class Ernie4_5_VisionAttention(nn.Module):
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Ernie45-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -291,7 +291,7 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
@@ -310,7 +310,7 @@ class Ernie4_5_VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -370,7 +370,7 @@ class Ernie4_5_VisionBlock(nn.Module):
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -463,7 +463,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
patch_size = vision_config.patch_size
@@ -515,10 +515,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:
@@ -565,11 +566,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens

View File

@@ -46,7 +46,7 @@ from transformers.models.glm4v.image_processing_glm4v import (
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
@@ -252,7 +252,7 @@ class Glm4vVisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -306,18 +306,18 @@ class Glm4vVisionAttention(nn.Module):
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"GLM-4V does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -377,7 +377,7 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
@@ -396,7 +396,7 @@ class Glm4vVisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -425,7 +425,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -703,7 +703,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -772,10 +772,11 @@ class Glm4vVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:
@@ -824,8 +825,8 @@ class Glm4vVisionTransformer(nn.Module):
max_seqlen, seqlens = None, None
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens

View File

@@ -16,7 +16,7 @@ from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
maybe_get_vit_flash_attn_backend,
)
@@ -360,7 +360,7 @@ class KeyeSiglipAttention(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -414,17 +414,17 @@ class KeyeSiglipAttention(nn.Module):
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
@@ -489,7 +489,7 @@ class KeyeSiglipAttention(nn.Module):
softmax_scale=self.scale,
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -536,7 +536,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -590,7 +590,7 @@ class KeyeSiglipEncoder(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -685,7 +685,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -768,7 +768,7 @@ class KeyeSiglipVisionModel(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

View File

@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -106,7 +106,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -135,7 +135,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
model_type = config.model_type
if model_type == "siglip2_navit":

View File

@@ -31,7 +31,7 @@ from transformers.modeling_outputs import (
)
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
@@ -580,8 +580,8 @@ class SiglipAttention(nn.Module):
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend: _Backend = _Backend.TORCH_SDPA,
attn_backend_override: _Backend | None = None,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
@@ -621,8 +621,8 @@ class SiglipAttention(nn.Module):
)
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -680,10 +680,10 @@ class SiglipAttention(nn.Module):
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
@@ -702,7 +702,7 @@ class SiglipAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
@@ -786,8 +786,8 @@ class SiglipEncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
attn_backend: _Backend = _Backend.TORCH_SDPA,
attn_backend_override: _Backend | None = None,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
):
super().__init__()
@@ -847,7 +847,7 @@ class SiglipEncoder(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -861,16 +861,16 @@ class SiglipEncoder(nn.Module):
)
self.use_upstream_fa = False
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
@@ -943,9 +943,12 @@ class SiglipEncoder(nn.Module):
max_seqlen = None
seqlens = None
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
hidden_states = inputs_embeds
@@ -966,7 +969,7 @@ class SiglipVisionTransformer(nn.Module):
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -1016,7 +1019,7 @@ class SiglipVisionModel(nn.Module):
config,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

View File

@@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLVisionConfig,
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
@@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module):
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module):
cu_seqlens,
max_seqlen,
batch_size,
self.attn_backend == _Backend.ROCM_AITER_FA,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
@@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module):
v,
cu_seqlens,
)
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
output, _ = self.proj(context_layer)
@@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens

View File

@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
@@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module):
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2-VL does not support {self.attn_backend} backend now."
)
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
output, "(b s) h d -> s b (h d)", b=batch_size
).contiguous()
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
@@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:
@@ -789,9 +790,12 @@ class Qwen2VisionTransformer(nn.Module):
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return max_seqlen, seqlens

View File

@@ -47,7 +47,7 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
)
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -301,7 +301,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -377,10 +377,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:
@@ -490,9 +491,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == _Backend.FLASH_ATTN:
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens

View File

@@ -49,7 +49,7 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
)
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -198,7 +198,7 @@ class Qwen3_VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
@@ -306,7 +306,7 @@ class Qwen3_VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
@@ -372,18 +372,18 @@ class Qwen3_VisionTransformer(nn.Module):
)
use_upstream_fa = False
if (
self.attn_backend != _Backend.FLASH_ATTN
and self.attn_backend != _Backend.ROCM_AITER_FA
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = _Backend.FLASH_ATTN
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
@@ -510,11 +510,11 @@ class Qwen3_VisionTransformer(nn.Module):
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens

View File

@@ -12,7 +12,7 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -208,7 +208,7 @@ class Siglip2Attention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -264,14 +264,14 @@ class Siglip2Attention(nn.Module):
)
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
self.attn_backend = _Backend.TORCH_SDPA
self.attn_backend = AttentionBackendEnum.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
def forward(
@@ -308,7 +308,7 @@ class Siglip2Attention(nn.Module):
attn_output = self.flash_attn_varlen_func(
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
).reshape(seq_length, -1)
elif self.attn_backend == _Backend.TORCH_SDPA:
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
outputs = []
@@ -376,7 +376,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -440,7 +440,7 @@ class Siglip2Encoder(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -626,7 +626,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -667,7 +667,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()

View File

@@ -10,7 +10,7 @@ from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar
import torch
from transformers import PretrainedConfig
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
@@ -83,8 +83,8 @@ def get_vit_attn_backend(
head_size: int,
dtype: torch.dtype,
*,
attn_backend_override: _Backend | None = None,
) -> _Backend:
attn_backend_override: AttentionBackendEnum | None = None,
) -> AttentionBackendEnum:
"""
Get the available attention backend for Vision Transformer.
"""
@@ -94,7 +94,7 @@ def get_vit_attn_backend(
# Lazy import to avoid circular dependency
from vllm.attention.selector import get_env_variable_attn_backend
selected_backend: _Backend | None = get_env_variable_attn_backend()
selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
if selected_backend is not None:
return selected_backend