Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
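Most of the churn in the hunks below is mechanical: '''-quoted docstrings become """-quoted, backslash and parenthesis-aligned continuations collapse onto one line when they fit the line limit, and calls that still do not fit are exploded to one argument per line with a trailing comma and a dedented closing parenthesis. A minimal before/after sketch of those patterns, assuming ruff format's default 88-character limit and using a made-up helper (frobnicate and its argument names are illustrative, not vLLM code):

# Before: yapf-style wrapping, aligned to the opening parenthesis
value = frobnicate(first_argument, second_argument, third_argument,
                   fourth_argument, fifth_argument, sixth_argument)

# After: ruff/Black-style hanging indent with a magic trailing comma
value = frobnicate(
    first_argument,
    second_argument,
    third_argument,
    fourth_argument,
    fifth_argument,
    sixth_argument,
)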
@@ -2,10 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
+from collections.abc import Generator
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import cache
-from typing import Generator, Optional, Union
+from typing import Optional, Union
 
 import torch
 
@@ -29,12 +30,11 @@ def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
       loaded.
     """
     assert backend_name is not None
-    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
-        None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else None
 
 
 def get_env_variable_attn_backend() -> Optional[_Backend]:
-    '''
+    """
     Get the backend override specified by the vLLM attention
     backend environment variable, if one is specified.
 
@@ -42,10 +42,9 @@ def get_env_variable_attn_backend() -> Optional[_Backend]:
 
     * _Backend enum value if an override is specified
     * None otherwise
-    '''
+    """
     backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
-    return (None
-            if backend_name is None else backend_name_to_enum(backend_name))
+    return None if backend_name is None else backend_name_to_enum(backend_name)
 
 
 # Global state allows a particular choice of backend
@@ -59,7 +58,7 @@ forced_attn_backend: Optional[_Backend] = None
 
 
 def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
-    '''
+    """
     Force all attention operations to use a specified backend.
 
     Passing `None` for the argument re-enables automatic
@@ -68,16 +67,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
     Arguments:
 
     * attn_backend: backend selection (None to revert to auto)
-    '''
+    """
     global forced_attn_backend
     forced_attn_backend = attn_backend
 
 
 def get_global_forced_attn_backend() -> Optional[_Backend]:
-    '''
+    """
     Get the currently-forced choice of attention backend,
     or None if auto-selection is currently enabled.
-    '''
+    """
     return forced_attn_backend
 
 
@@ -110,26 +109,27 @@ def is_attn_backend_supported(
     assert isinstance(attn_backend, type)
 
     # TODO: Update the interface once V0 is removed
-    if get_supported_head_sizes := getattr(attn_backend,
-                                           "get_supported_head_sizes", None):
+    if get_supported_head_sizes := getattr(
+        attn_backend, "get_supported_head_sizes", None
+    ):
         is_head_size_supported = head_size in get_supported_head_sizes()
-    elif validate_head_size := getattr(attn_backend, "validate_head_size",
-                                       None):
+    elif validate_head_size := getattr(attn_backend, "validate_head_size", None):
         try:
             validate_head_size(head_size)
             is_head_size_supported = True
         except Exception:
             is_head_size_supported = False
     else:
-        raise NotImplementedError(f"{attn_backend.__name__} does not support "
-                                  "head size validation")
+        raise NotImplementedError(
+            f"{attn_backend.__name__} does not support head size validation"
+        )
 
-    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
-                                       None):
+    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None):
         is_dtype_supported = dtype in get_supported_dtypes()
     else:
-        raise NotImplementedError(f"{attn_backend.__name__} does not support "
-                                  "dtype validation")
+        raise NotImplementedError(
+            f"{attn_backend.__name__} does not support dtype validation"
+        )
 
     return _IsSupported(
         can_import=True,
@@ -175,15 +175,13 @@ def _cached_get_attn_backend(
     has_sink: bool = False,
     use_sparse: bool = False,
 ) -> type[AttentionBackend]:
-
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
     selected_backend = None
-    backend_by_global_setting: Optional[_Backend] = (
-        get_global_forced_attn_backend())
+    backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend()
     if backend_by_global_setting is not None:
         selected_backend = backend_by_global_setting
     else:
@@ -195,29 +193,41 @@ def _cached_get_attn_backend(
                     "The suffix '_VLLM_V1' in the environment variable "
                     "%s is no longer necessary as V0 backends have been "
                     "deprecated. Please remove this suffix from your "
-                    "environment variable setting.", STR_BACKEND_ENV_VAR)
-                backend_by_env_var = backend_by_env_var.removesuffix(
-                    "_VLLM_V1")
+                    "environment variable setting.",
+                    STR_BACKEND_ENV_VAR,
+                )
+                backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
             selected_backend = backend_name_to_enum(backend_by_env_var)
             if selected_backend is None:
                 raise ValueError(
                     f"Invalid attention backend: '{backend_by_env_var}'. "
-                    f"Valid backends are: {list(_Backend.__members__.keys())}")
+                    f"Valid backends are: {list(_Backend.__members__.keys())}"
+                )
 
     # get device-specific attn_backend
     attention_cls = current_platform.get_attn_backend_cls(
-        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
-        use_mla, has_sink, use_sparse)
+        selected_backend,
+        head_size,
+        dtype,
+        kv_cache_dtype,
+        block_size,
+        use_v1,
+        use_mla,
+        has_sink,
+        use_sparse,
+    )
     if not attention_cls:
         raise ValueError(
-            f"Invalid attention backend for {current_platform.device_name}")
+            f"Invalid attention backend for {current_platform.device_name}"
+        )
     return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
 def global_force_attn_backend_context_manager(
-        attn_backend: _Backend) -> Generator[None, None, None]:
-    '''
+    attn_backend: _Backend,
+) -> Generator[None, None, None]:
+    """
     Globally force a vLLM attention backend override within a
     context manager, reverting the global attention backend
     override to its prior state upon exiting the context
@@ -230,7 +240,7 @@ def global_force_attn_backend_context_manager(
     Returns:
 
     * Generator
-    '''
+    """
 
     # Save the current state of the global backend override (if any)
     original_value = get_global_forced_attn_backend()
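For readers skimming the last hunk, a short usage sketch of the context manager it reformats may help. This is not part of the diff; the import paths and the FLASH_ATTN member name are assumptions about the surrounding codebase rather than something this commit shows.

# Usage sketch only -- import locations and the FLASH_ATTN member are assumptions.
from vllm.attention.selector import (
    get_global_forced_attn_backend,
    global_force_attn_backend_context_manager,
)
from vllm.platforms import _Backend

# Assumes no global override was set beforehand.
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    # Inside the block, the forced backend is visible to the selector.
    assert get_global_forced_attn_backend() == _Backend.FLASH_ATTN
# On exit, the prior state (here: automatic selection) is restored.
assert get_global_forced_attn_backend() is None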