[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-11-11 06:40:44 -06:00
committed by GitHub
parent 2e78150d24
commit b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

View File

@@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass
 from vllm.config.utils import config

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
 else:
-    _Backend = Any
+    AttentionBackendEnum = Any

 @dataclass
@@ -125,10 +125,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
-    mm_encoder_attn_backend: _Backend | None = None
+    mm_encoder_attn_backend: AttentionBackendEnum | None = None
     """Optional override for the multi-modal encoder attention backend when
     using vision transformers. Accepts any value from
-    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
+    `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""
@@ -167,26 +167,16 @@ class MultiModalConfig:
     @field_validator("mm_encoder_attn_backend", mode="before")
     @classmethod
-    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
-        from vllm.attention.backends.registry import (
-            _Backend as BackendEnum,
-        )
-        from vllm.attention.backends.registry import (
-            backend_name_to_enum,
-        )
-        if value is None or isinstance(value, BackendEnum):
-            return value
-        if isinstance(value, str):
-            candidate = backend_name_to_enum(value.upper())
-            if candidate is not None:
-                return candidate
-        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
-        raise ValueError(
-            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
-        )
+    def _validate_mm_encoder_attn_backend(
+        cls, value: str | AttentionBackendEnum | None
+    ) -> AttentionBackendEnum | None:
+        if value is None or isinstance(value, AttentionBackendEnum):
+            return value
+        assert isinstance(value, str), (
+            "mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
+        )
+        return AttentionBackendEnum[value.upper()]

     @model_validator(mode="after")
     def _validate_multimodal_config(self):