[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2024-11-19 11:22:26 +08:00
committed by GitHub
parent 7eb719df13
commit 8c1fb50705
14 changed files with 99 additions and 69 deletions

View File

@@ -1,4 +1,3 @@
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
@@ -9,26 +8,12 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__)
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None
@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
if current_platform.is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if current_platform.is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
if current_platform.is_hpu():
return _Backend.HPU_ATTN
# get device-specific default attn_backend
default_backend = current_platform.get_default_attn_backend(
selected_backend)
if default_backend is not None:
return default_backend
if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1