Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
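
For readers skimming the diff, here is a minimal before/after sketch of the rewrite this commit applies throughout: PEP 604 union syntax (`X | Y`, `X | None`), native on Python 3.10+ and requiring no `typing` import. The `parse` function and its parameters are made-up for illustration and do not appear in the diff below.

```python
# Illustrative example only; `parse`, `value`, and `default` are hypothetical names.

# Before: requires `from typing import Optional, Union`
#   def parse(value: Union[int, str], default: Optional[int] = None) -> Optional[int]: ...

# After: PEP 604 unions, no typing import needed (Python 3.10+)
def parse(value: int | str, default: int | None = None) -> int | None:
    """Return `value` as an int, or `default` if it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default
```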
@@ -6,7 +6,6 @@ from collections.abc import Generator
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import cache
-from typing import Optional, Union

 import torch

@@ -19,7 +18,7 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 logger = init_logger(__name__)


-def get_env_variable_attn_backend() -> Optional[_Backend]:
+def get_env_variable_attn_backend() -> _Backend | None:
     """
     Get the backend override specified by the vLLM attention
     backend environment variable, if one is specified.
@@ -40,10 +39,10 @@ def get_env_variable_attn_backend() -> Optional[_Backend]:
 #
 # THIS SELECTION TAKES PRECEDENCE OVER THE
 # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
-forced_attn_backend: Optional[_Backend] = None
+forced_attn_backend: _Backend | None = None


-def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
+def global_force_attn_backend(attn_backend: _Backend | None) -> None:
     """
     Force all attention operations to use a specified backend.

@@ -58,7 +57,7 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
     forced_attn_backend = attn_backend


-def get_global_forced_attn_backend() -> Optional[_Backend]:
+def get_global_forced_attn_backend() -> _Backend | None:
     """
     Get the currently-forced choice of attention backend,
     or None if auto-selection is currently enabled.
@@ -77,7 +76,7 @@ class _IsSupported:


 def is_attn_backend_supported(
-    attn_backend: Union[str, type[AttentionBackend]],
+    attn_backend: str | type[AttentionBackend],
     head_size: int,
     dtype: torch.dtype,
     *,
@@ -127,7 +126,7 @@ def is_attn_backend_supported(
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
+    kv_cache_dtype: str | None,
     block_size: int,
     use_mla: bool = False,
     has_sink: bool = False,
@@ -154,7 +153,7 @@ def get_attn_backend(
 def _cached_get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
+    kv_cache_dtype: str | None,
     block_size: int,
     use_v1: bool = False,
     use_mla: bool = False,
@@ -167,12 +166,12 @@ def _cached_get_attn_backend(
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
     selected_backend = None
-    backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend()
+    backend_by_global_setting: _Backend | None = get_global_forced_attn_backend()
     if backend_by_global_setting is not None:
         selected_backend = backend_by_global_setting
     else:
         # Check the environment variable and override if specified
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND
         if backend_by_env_var is not None:
             if backend_by_env_var.endswith("_VLLM_V1"):
                 logger.warning(
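
As a usage note on the precedence logic touched in the last hunk (the globally forced backend wins over `VLLM_ATTENTION_BACKEND`), here is a minimal sketch of both override paths. The import paths for `global_force_attn_backend`, `get_global_forced_attn_backend`, and `_Backend`, as well as the `FLASH_ATTN`/`XFORMERS` names, are assumptions about the surrounding vLLM code and may differ across versions.

```python
# Minimal sketch, not part of this diff. Import paths and enum members are assumed;
# verify them against your vLLM version.
import os

from vllm.attention.selector import (
    get_global_forced_attn_backend,
    global_force_attn_backend,
)
from vllm.platforms import _Backend  # assumed location of the _Backend enum

# Lower-precedence override: environment variable, read by _cached_get_attn_backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

# Higher-precedence override: the module-level forced backend set here wins
# over the environment variable during backend selection.
global_force_attn_backend(_Backend.FLASH_ATTN)
assert get_global_forced_attn_backend() is _Backend.FLASH_ATTN

# Passing None restores automatic backend selection.
global_force_attn_backend(None)
assert get_global_forced_attn_backend() is None
```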