[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
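The platform-side half of this refactor is not shown in the diff below. A minimal sketch of the idea, assuming _Backend and the Platform base class are colocated in vllm/platforms/interface.py (names inferred from the call site in the last hunk; not this PR's literal code):

# A minimal sketch, not this PR's literal code: assumes _Backend and the
# Platform base class live together in vllm/platforms/interface.py.
import enum
from typing import Optional


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()


class Platform:

    @classmethod
    def get_default_attn_backend(
            cls, selected_backend: "_Backend") -> Optional["_Backend"]:
        # No device-specific default: which_attn_to_use() falls through
        # to its generic GPU selection logic.
        return None


class CpuPlatform(Platform):

    @classmethod
    def get_default_attn_backend(
            cls, selected_backend: "_Backend") -> Optional["_Backend"]:
        # Stands in for the is_cpu() branch deleted below: CPU always
        # uses torch SDPA regardless of what was requested.
        return _Backend.TORCH_SDPA

Each device type then overrides one classmethod instead of adding a branch to the selector.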
@@ -1,4 +1,3 @@
-import enum
 import os
 from contextlib import contextmanager
 from functools import lru_cache
@@ -9,26 +8,12 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR
 
 logger = init_logger(__name__)
 
 
-class _Backend(enum.Enum):
-    FLASH_ATTN = enum.auto()
-    FLASH_ATTN_VLLM_V1 = enum.auto()
-    XFORMERS = enum.auto()
-    ROCM_FLASH = enum.auto()
-    TORCH_SDPA = enum.auto()
-    OPENVINO = enum.auto()
-    FLASHINFER = enum.auto()
-    HPU_ATTN = enum.auto()
-    PALLAS = enum.auto()
-    IPEX = enum.auto()
-    NO_ATTENTION = enum.auto()
-
-
 def backend_name_to_enum(backend_name: str) -> _Backend:
     assert backend_name is not None
 
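For context on the hunk below: the override it consumes arrives via the environment variable that STR_BACKEND_ENV_VAR names in vllm/utils.py, and backend_name_to_enum() (kept above) maps the string onto the relocated enum. An illustrative override:

# Illustrative only: VLLM_ATTENTION_BACKEND is the variable that
# STR_BACKEND_ENV_VAR points at; the string is converted to an enum
# member by backend_name_to_enum() at the top of the hunk below.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"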
@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    if current_platform.is_cpu():
-        if selected_backend != _Backend.TORCH_SDPA:
-            logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
-
-    if current_platform.is_openvino():
-        if selected_backend != _Backend.OPENVINO:
-            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
-        return _Backend.OPENVINO
-
-    if current_platform.is_xpu():
-        if selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
-        return _Backend.IPEX
-
-    if current_platform.is_tpu():
-        if selected_backend != _Backend.PALLAS:
-            logger.info("Cannot use %s backend on TPU.", selected_backend)
-        return _Backend.PALLAS
-
-    if current_platform.is_rocm():
-        # AMD GPUs.
-        selected_backend = (_Backend.ROCM_FLASH if selected_backend
-                            == _Backend.FLASH_ATTN else selected_backend)
-        if selected_backend == _Backend.ROCM_FLASH:
-            if not current_platform.has_device_capability(90):
-                # not Instinct series GPUs.
-                logger.info("flash_attn is not supported on NAVI GPUs.")
-        else:
-            logger.info("%s is not supported in AMD GPUs.", selected_backend)
-        return _Backend.ROCM_FLASH
-
-    if current_platform.is_hpu():
-        return _Backend.HPU_ATTN
+    # get device-specific default attn_backend
+    default_backend = current_platform.get_default_attn_backend(
+        selected_backend)
+    if default_backend is not None:
+        return default_backend
 
     if use_v1:
         return _Backend.FLASH_ATTN_VLLM_V1
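The net effect of the last hunk: per-device policy moves out of the selector's if-ladder and behind a single platform call. A hedged illustration of the new flow (platform behavior as sketched in the note above the diff):

# Hedged illustration of the new flow; on a CPU build this would print
# the platform default, e.g. _Backend.TORCH_SDPA.
from vllm.platforms import _Backend, current_platform

default_backend = current_platform.get_default_attn_backend(
    _Backend.FLASH_ATTN)
if default_backend is not None:
    print("platform default:", default_backend)

Supporting a new device type now means overriding one classmethod on its Platform subclass rather than editing which_attn_to_use().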