[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -4,14 +4,15 @@
 import os
 from collections.abc import Generator
 from contextlib import contextmanager
-from dataclasses import dataclass
 from functools import cache
+from typing import cast, get_args

 import torch

 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import _Backend, backend_name_to_enum
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.utils import STR_BACKEND_ENV_VAR
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -19,18 +20,18 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
 logger = init_logger(__name__)


-def get_env_variable_attn_backend() -> _Backend | None:
+def get_env_variable_attn_backend() -> AttentionBackendEnum | None:
     """
     Get the backend override specified by the vLLM attention
     backend environment variable, if one is specified.

     Returns:

-    * _Backend enum value if an override is specified
+    * AttentionBackendEnum value if an override is specified
     * None otherwise
     """
     backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
-    return None if backend_name is None else backend_name_to_enum(backend_name)
+    return None if backend_name is None else AttentionBackendEnum[backend_name]


 # Global state allows a particular choice of backend
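Note on the change above: the environment override is now resolved with a direct AttentionBackendEnum lookup instead of the removed backend_name_to_enum() helper, so an unknown name raises KeyError at this point. A minimal sketch of the new lookup path (the backend name is just an illustrative member; STR_BACKEND_ENV_VAR refers to the VLLM_ATTENTION_BACKEND variable):

    import os
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.utils import STR_BACKEND_ENV_VAR

    os.environ[STR_BACKEND_ENV_VAR] = "FLASH_ATTN"  # illustrative override
    name = os.environ.get(STR_BACKEND_ENV_VAR)
    # Same expression as in get_env_variable_attn_backend(): enum lookup by member name
    backend = None if name is None else AttentionBackendEnum[name]
    print(backend)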
@@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None:
 #
 # THIS SELECTION TAKES PRECEDENCE OVER THE
 # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
-forced_attn_backend: _Backend | None = None
+forced_attn_backend: AttentionBackendEnum | None = None


-def global_force_attn_backend(attn_backend: _Backend | None) -> None:
+def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None:
     """
     Force all attention operations to use a specified backend.

@@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None:
     forced_attn_backend = attn_backend


-def get_global_forced_attn_backend() -> _Backend | None:
+def get_global_forced_attn_backend() -> AttentionBackendEnum | None:
     """
     Get the currently-forced choice of attention backend,
     or None if auto-selection is currently enabled.
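For context, this pair of helpers is the programmatic override path. A small usage sketch (the enum member is illustrative, and the import path assumes these helpers remain in vllm.attention.selector as in upstream vLLM):

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.selector import (
        get_global_forced_attn_backend,
        global_force_attn_backend,
    )

    global_force_attn_backend(AttentionBackendEnum.FLASH_ATTN)  # force a backend
    assert get_global_forced_attn_backend() is AttentionBackendEnum.FLASH_ATTN
    global_force_attn_backend(None)  # restore auto-selection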
@@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None:
     return forced_attn_backend


-@dataclass(frozen=True)
-class _IsSupported:
-    can_import: bool
-    head_size: bool
-    dtype: bool
-
-    def __bool__(self) -> bool:
-        return self.can_import and self.head_size and self.dtype
-
-
-def is_attn_backend_supported(
-    attn_backend: str | type[AttentionBackend],
-    head_size: int,
-    dtype: torch.dtype,
-    *,
-    allow_import_error: bool = True,
-) -> _IsSupported:
-    if isinstance(attn_backend, str):
-        try:
-            attn_backend = resolve_obj_by_qualname(attn_backend)
-        except ImportError:
-            if not allow_import_error:
-                raise
-
-            return _IsSupported(can_import=False, head_size=False, dtype=False)
-
-    assert isinstance(attn_backend, type)
-
-    # TODO: Update the interface once V0 is removed
-    if get_supported_head_sizes := getattr(
-        attn_backend, "get_supported_head_sizes", None
-    ):
-        is_head_size_supported = head_size in get_supported_head_sizes()
-    elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
-        try:
-            validate_head_size(head_size)
-            is_head_size_supported = True
-        except Exception:
-            is_head_size_supported = False
-    else:
-        raise NotImplementedError(
-            f"{attn_backend.__name__} does not support head size validation"
-        )
-
-    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
-        is_dtype_supported = dtype in get_supported_dtypes()
-    else:
-        raise NotImplementedError(
-            f"{attn_backend.__name__} does not support dtype validation"
-        )
-
-    return _IsSupported(
-        can_import=True,
-        head_size=is_head_size_supported,
-        dtype=is_dtype_supported,
-    )
-
-
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
     kv_cache_dtype: str | None,
-    block_size: int,
+    block_size: int | None,
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
+
+    if kv_cache_dtype is not None:
+        valid_cache_dtypes = get_args(CacheDType)
+        assert kv_cache_dtype in valid_cache_dtypes, (
+            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
+            f"Valid values are: {valid_cache_dtypes}"
+        )
+
     return _cached_get_attn_backend(
         head_size=head_size,
         dtype=dtype,
-        kv_cache_dtype=kv_cache_dtype,
+        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
         block_size=block_size,
         use_mla=use_mla,
         has_sink=has_sink,
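The public entry point keeps its shape apart from block_size now accepting None, and gains an early guard on kv_cache_dtype. A hedged usage sketch (head size, dtype, and block size are arbitrary example values; the accepted kv_cache_dtype strings are whatever get_args(CacheDType) yields, e.g. "auto"):

    import torch
    from vllm.attention.selector import get_attn_backend

    backend_cls = get_attn_backend(
        head_size=128,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",  # must be None or a CacheDType literal, else the assert fires
        block_size=16,
    )
    print(backend_cls.get_name())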
@@ -149,8 +100,8 @@ def get_attn_backend(
 def _cached_get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
-    kv_cache_dtype: str | None,
-    block_size: int,
+    kv_cache_dtype: CacheDType | None,
+    block_size: int | None,
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
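The cached helper now takes the narrower CacheDType annotation instead of a bare str. CacheDType is a Literal alias in vllm.config.cache, so the valid strings can be enumerated at runtime; a quick sketch (the printed members depend on the vLLM version):

    from typing import get_args
    from vllm.config.cache import CacheDType

    print(get_args(CacheDType))  # accepted kv-cache dtype strings, e.g. "auto" and fp8 variants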
@@ -161,7 +112,9 @@ def _cached_get_attn_backend(
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
     selected_backend = None
-    backend_by_global_setting: _Backend | None = get_global_forced_attn_backend()
+    backend_by_global_setting: AttentionBackendEnum | None = (
+        get_global_forced_attn_backend()
+    )
     if backend_by_global_setting is not None:
         selected_backend = backend_by_global_setting
     else:
@@ -177,12 +130,13 @@ def _cached_get_attn_backend(
                     STR_BACKEND_ENV_VAR,
                 )
                 backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-            if selected_backend is None:
+            try:
+                selected_backend = AttentionBackendEnum[backend_by_env_var]
+            except KeyError as e:
                 raise ValueError(
-                    f"Invalid attention backend: '{backend_by_env_var}'. "
-                    f"Valid backends are: {list(_Backend.__members__.keys())}"
-                )
+                    f"Invalid attention backend: '{backend_by_env_var}'. Valid "
+                    f"backends are: {list(AttentionBackendEnum.__members__.keys())}"
+                ) from e

     # get device-specific attn_backend
     from vllm.platforms import current_platform
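A behavioral nuance of the new lookup: an invalid VLLM_ATTENTION_BACKEND value now surfaces as a ValueError explicitly chained (via `from e`) to the KeyError raised by the enum lookup, so the original cause stays visible in tracebacks. A minimal sketch of the same pattern (lookup() is a hypothetical helper used only for illustration):

    from vllm.attention.backends.registry import AttentionBackendEnum

    def lookup(name: str) -> AttentionBackendEnum:
        try:
            return AttentionBackendEnum[name]
        except KeyError as e:
            raise ValueError(f"Invalid attention backend: {name!r}") from e

    try:
        lookup("NOT_A_BACKEND")
    except ValueError as err:
        assert isinstance(err.__cause__, KeyError)  # original KeyError preserved as the cause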
@@ -202,12 +156,26 @@ def _cached_get_attn_backend(
         raise ValueError(
             f"Invalid attention backend for {current_platform.device_name}"
         )
-    return resolve_obj_by_qualname(attention_cls)
+    backend = resolve_obj_by_qualname(attention_cls)
+
+    # Adjust kv cache layout if the selected backend requires a specific one
+    required_layout = backend.get_required_kv_cache_layout()
+    if required_layout is not None:
+        from vllm.v1.attention.backends.utils import set_kv_cache_layout
+
+        set_kv_cache_layout(required_layout)
+        logger.info(
+            "Using %s KV cache layout for %s backend.",
+            required_layout,
+            backend.get_name(),
+        )
+
+    return backend


 @contextmanager
 def global_force_attn_backend_context_manager(
-    attn_backend: _Backend,
+    attn_backend: AttentionBackendEnum,
 ) -> Generator[None, None, None]:
     """
     Globally force a vLLM attention backend override within a
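The tail of the selection path now lets the chosen backend demand a specific KV cache layout before its class is returned, and the existing context manager still provides a scoped version of the global override. A hedged usage sketch (backend member and layout value are illustrative):

    import torch
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.selector import (
        get_attn_backend,
        global_force_attn_backend_context_manager,
    )

    with global_force_attn_backend_context_manager(AttentionBackendEnum.FLASHINFER):
        # Within this block the forced backend wins the selection; if it reports a
        # required KV cache layout (e.g. "HND"), set_kv_cache_layout() is called
        # during selection and an info log is emitted.
        backend_cls = get_attn_backend(
            head_size=128, dtype=torch.bfloat16, kv_cache_dtype="auto", block_size=16
        )
    # On exit the previously forced backend (or None) is restored.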