[platform] Allow platform to specify attention backend (#11609)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
 
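The newly imported resolve_obj_by_qualname is the piece that lets a backend be named as a string and materialized later. A minimal sketch of such a resolver, assuming the standard importlib pattern (the real helper lives in vllm.utils; this stand-in is only illustrative):

import importlib

def resolve_obj_by_qualname_sketch(qualname: str):
    # Assumed behavior of vllm.utils.resolve_obj_by_qualname:
    # "pkg.module.ClassName" -> the ClassName object itself.
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)

# Example, using a class path that appears in this diff:
# resolve_obj_by_qualname_sketch(
#     "vllm.attention.backends.flash_attn.FlashAttentionBackend")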
@@ -114,83 +114,19 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
-    """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
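The hunk above keeps both existing override channels in front of the new platform hook: a globally forced backend (get_global_forced_attn_backend) takes precedence over the VLLM_ATTENTION_BACKEND environment variable, and the result, possibly None, is what reaches the platform in the next hunk. A hedged usage sketch; FLASHINFER is one of the backend names visible in this file:

import os

# Must be set before vLLM resolves the attention backend; the selector
# reads it through envs.VLLM_ATTENTION_BACKEND.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

# A backend forced through the selector's global-force helper would still
# override this variable, per the "THIS SELECTION OVERRIDES" comment above.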
@@ -201,64 +137,13 @@ def which_attn_to_use(head_size: int,
         if backend_by_env_var is not None:
             selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
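The net contract after this change: the platform returns the fully qualified class name of an attention backend (an empty value signals "unsupported"), and the selector materializes it with resolve_obj_by_qualname. A sketch of what a platform-side implementation could look like; the method name and argument order mirror the call site above, but MyPlatform and its body are hypothetical, not vLLM's actual platform code:

from typing import Optional

import torch


class MyPlatform:
    # Hypothetical platform class; only the hook from this diff is shown.
    device_name: str = "my_device"

    def get_attn_backend_cls(self, selected_backend, head_size: int,
                             dtype: torch.dtype,
                             kv_cache_dtype: Optional[str], block_size: int,
                             use_v1: bool) -> str:
        # selected_backend is a _Backend enum member or None.
        # Return the dotted path of a backend class; the selector resolves
        # it via resolve_obj_by_qualname. Returning "" makes the caller
        # raise "Invalid attention backend for my_device".
        if use_v1:
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

This keeps hardware-specific dispatch out of the shared selector: an out-of-tree platform can choose its backend without touching vllm/attention/selector.py.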