[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
@@ -1,6 +1,8 @@
 import enum
 import os
+from contextlib import contextmanager
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Generator, Optional, Type
 
 import torch
 
@@ -8,7 +10,8 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
+from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
+                        is_tpu, is_xpu)
 
 logger = init_logger(__name__)
 
@@ -24,6 +27,66 @@ class _Backend(enum.Enum):
     IPEX = enum.auto()
 
 
+def backend_name_to_enum(backend_name: str) -> _Backend:
+    assert backend_name is not None
+
+    backend_members = _Backend.__members__
+    if backend_name not in backend_members:
+        raise ValueError(f"Invalid attention backend '{backend_name}'. "
+                         f"Available backends: {', '.join(backend_members)} "
+                         "(case-sensitive).")
+
+    return _Backend[backend_name]
+
+
+def get_env_variable_attn_backend() -> Optional[_Backend]:
+    '''
+    Get the backend override specified by the vLLM attention
+    backend environment variable, if one is specified.
+
+    Returns:
+
+    * _Backend enum value if an override is specified
+    * None otherwise
+    '''
+    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
+    return (None
+            if backend_name is None else backend_name_to_enum(backend_name))
+
+
+# Global state allows a particular choice of backend
+# to be forced, overriding the logic which auto-selects
+# a backend based on system & workload configuration
+# (default behavior if this variable is None)
+#
+# THIS SELECTION TAKES PRECEDENCE OVER THE
+# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
+forced_attn_backend: Optional[_Backend] = None
+
+
+def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
+    '''
+    Force all attention operations to use a specified backend.
+
+    Passing `None` for the argument re-enables automatic
+    backend selection.
+
+    Arguments:
+
+    * attn_backend: backend selection (None to revert to auto)
+    '''
+    global forced_attn_backend
+    forced_attn_backend = attn_backend
+
+
+def get_global_forced_attn_backend() -> Optional[_Backend]:
+    '''
+    Get the currently-forced choice of attention backend,
+    or None if auto-selection is currently enabled.
+    '''
+    return forced_attn_backend
+
+
 @lru_cache(maxsize=None)
 def get_attn_backend(
     num_heads: int,
@@ -101,16 +164,20 @@ def which_attn_to_use(
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
 
-    # Check the environment variable and override if specified
-    backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-    if backend_by_env_var is not None:
-        backend_members = _Backend.__members__
-        if backend_by_env_var not in backend_members:
-            raise ValueError(
-                f"Invalid attention backend '{backend_by_env_var}'. "
-                f"Available backends: {', '.join(backend_members)} "
-                "(case-sensitive).")
-        selected_backend = _Backend[backend_by_env_var]
+    # Check whether a particular choice of backend was
+    # previously forced.
+    #
+    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
+    # ENVIRONMENT VARIABLE.
+    backend_by_global_setting: Optional[_Backend] = (
+        get_global_forced_attn_backend())
+    if backend_by_global_setting is not None:
+        selected_backend = backend_by_global_setting
+    else:
+        # Check the environment variable and override if specified
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
 
     if is_cpu():
         if selected_backend != _Backend.TORCH_SDPA:
@@ -193,3 +260,35 @@ def which_attn_to_use(
         selected_backend = _Backend.XFORMERS
 
     return selected_backend
+
+
+@contextmanager
+def global_force_attn_backend_context_manager(
+        attn_backend: _Backend) -> Generator[None, None, None]:
+    '''
+    Globally force a vLLM attention backend override within a
+    context manager, reverting the global attention backend
+    override to its prior state upon exiting the context
+    manager.
+
+    Arguments:
+
+    * attn_backend: attention backend to force
+
+    Returns:
+
+    * Generator
+    '''
+
+    # Save the current state of the global backend override (if any)
+    original_value = get_global_forced_attn_backend()
+
+    # Globally force the new backend override
+    global_force_attn_backend(attn_backend)
+
+    # Yield control back to the enclosed code block
+    try:
+        yield
+    finally:
+        # Revert the original global backend override, if any
+        global_force_attn_backend(original_value)
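
For reference, a minimal usage sketch of the new override hooks (not part of this commit). It assumes the helpers added above are importable from vllm.attention.selector, the module this diff appears to modify; the choice of _Backend.XFORMERS is purely illustrative.

# Usage sketch (illustrative only, not part of this commit).
from vllm.attention.selector import (
    _Backend, get_global_forced_attn_backend,
    global_force_attn_backend_context_manager)

# Force the XFORMERS backend only while the enclosed code runs; per the
# comments above, this override takes precedence over the
# VLLM_ATTENTION_BACKEND environment variable.
with global_force_attn_backend_context_manager(_Backend.XFORMERS):
    assert get_global_forced_attn_backend() == _Backend.XFORMERS
    # ... run code that must use the forced backend here ...

# On exit the prior override (here None, i.e. auto-selection) is restored.
assert get_global_forced_attn_backend() is None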