[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-11-11 06:40:44 -06:00
committed by GitHub
parent 2e78150d24
commit b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

View File

@@ -4,14 +4,15 @@
import os
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache
from typing import cast, get_args
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_env_variable_attn_backend() -> _Backend | None:
def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
"""
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* AttentionBackendEnum value if an override is specified
* None otherwise
"""
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return None if backend_name is None else backend_name_to_enum(backend_name)
return None if backend_name is None else AttentionBackendEnum[backend_name]
# Global state allows a particular choice of backend
@@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None:
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: _Backend | None = None
forced_attn_backend: AttentionBackendEnum | None = None
def global_force_attn_backend(attn_backend: _Backend | None) -> None:
def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
"""
Force all attention operations to use a specified backend.
@@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None:
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> _Backend | None:
def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
"""
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
@@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None:
return forced_attn_backend
@dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
def is_attn_backend_supported(
attn_backend: str | type[AttentionBackend],
head_size: int,
dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
if not allow_import_error:
raise
return _IsSupported(can_import=False, head_size=False, dtype=False)
assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(
attn_backend, "get_supported_head_sizes", None
):
is_head_size_supported = head_size in get_supported_head_sizes()
elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
is_head_size_supported = True
except Exception:
is_head_size_supported = False
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support head size validation"
)
if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(
f"{attn_backend.__name__} does not support dtype validation"
)
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size,
use_mla=use_mla,
has_sink=has_sink,
@@ -149,8 +100,8 @@ def get_attn_backend(
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
@@ -161,7 +112,9 @@ def _cached_get_attn_backend(
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: _Backend | None = get_global_forced_attn_backend()
backend_by_global_setting: AttentionBackendEnum | None = (
get_global_forced_attn_backend()
)
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
@@ -177,12 +130,13 @@ def _cached_get_attn_backend(
STR_BACKEND_ENV_VAR,
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
try:
selected_backend = AttentionBackendEnum[backend_by_env_var]
except KeyError as e:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)
f"Invalid attention backend: '{backend_by_env_var}'. Valid "
f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
) from e
# get device-specific attn_backend
from vllm.platforms import current_platform
@@ -202,12 +156,26 @@ def _cached_get_attn_backend(
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
return resolve_obj_by_qualname(attention_cls)
backend = resolve_obj_by_qualname(attention_cls)
# Adjust kv cache layout if the selected backend requires a specific one
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout(required_layout)
logger.info(
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
return backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: _Backend,
attn_backend: AttentionBackendEnum,
) -> Generator[None, None, None]:
"""
Globally force a vLLM attention backend override within a