[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -11,9 +11,9 @@ from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
 else:
-    _Backend = Any
+    AttentionBackendEnum = Any
 
 
 @dataclass
@@ -125,10 +125,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
-    mm_encoder_attn_backend: _Backend | None = None
+    mm_encoder_attn_backend: AttentionBackendEnum | None = None
     """Optional override for the multi-modal encoder attention backend when
     using vision transformers. Accepts any value from
-    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
+    `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""
@@ -167,26 +167,16 @@ class MultiModalConfig:
 
     @field_validator("mm_encoder_attn_backend", mode="before")
     @classmethod
-    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
-        from vllm.attention.backends.registry import (
-            _Backend as BackendEnum,
-        )
-        from vllm.attention.backends.registry import (
-            backend_name_to_enum,
-        )
-
-        if value is None or isinstance(value, BackendEnum):
+    def _validate_mm_encoder_attn_backend(
+        cls, value: str | AttentionBackendEnum | None
+    ) -> AttentionBackendEnum | None:
+        if value is None or isinstance(value, AttentionBackendEnum):
             return value
-
-        if isinstance(value, str):
-            candidate = backend_name_to_enum(value.upper())
-            if candidate is not None:
-                return candidate
-
-        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
-        raise ValueError(
-            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
-        )
+        assert isinstance(value, str), (
+            "mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
+        )
+        return AttentionBackendEnum[value.upper()]
 
     @model_validator(mode="after")
     def _validate_multimodal_config(self):
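
For context, a minimal sketch (not part of this commit) of how the refactored field and its before-mode validator behave: None and enum members pass through unchanged, strings are upper-cased and looked up by name in the backend enum. The Toy* names are illustrative stand-ins for AttentionBackendEnum and MultiModalConfig, and the sketch assumes Python 3.10+ with pydantic v2.

# Illustrative sketch only: mirrors the coercion behaviour of the refactored
# _validate_mm_encoder_attn_backend; the real enum and config class live in vllm.
from enum import Enum, auto

from pydantic import field_validator
from pydantic.dataclasses import dataclass


class ToyBackendEnum(Enum):  # stand-in for AttentionBackendEnum
    FLASH_ATTN = auto()
    XFORMERS = auto()


@dataclass
class ToyMultiModalConfig:  # stand-in for MultiModalConfig
    mm_encoder_attn_backend: ToyBackendEnum | None = None

    @field_validator("mm_encoder_attn_backend", mode="before")
    @classmethod
    def _coerce_backend(cls, value):
        # Pass through None and existing enum members unchanged.
        if value is None or isinstance(value, ToyBackendEnum):
            return value
        assert isinstance(value, str), (
            "mm_encoder_attn_backend must be a string or a ToyBackendEnum."
        )
        # Case-insensitive name lookup; an unknown name raises KeyError.
        return ToyBackendEnum[value.upper()]


cfg = ToyMultiModalConfig(mm_encoder_attn_backend="flash_attn")
assert cfg.mm_encoder_attn_backend is ToyBackendEnum.FLASH_ATTN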