[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2024-11-19 11:22:26 +08:00
committed by GitHub
parent 7eb719df13
commit 8c1fb50705
14 changed files with 99 additions and 69 deletions

View File

@@ -1,3 +1,4 @@
from .interface import _Backend # noqa: F401
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform

View File

@@ -5,7 +5,9 @@ import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -22,6 +24,12 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Resolve the attention backend for CPU.

    CPU only supports TORCH_SDPA; any other selection is logged and
    overridden.
    """
    forced_backend = _Backend.TORCH_SDPA
    if selected_backend is not forced_backend:
        logger.info("Cannot use %s backend on CPU.", selected_backend)
    return forced_backend
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return total "device" memory for CPU, i.e. total system RAM in bytes.

    ``device_id`` is accepted for interface parity with GPU platforms and
    is ignored here.
    """
    mem_stats = psutil.virtual_memory()
    return mem_stats.total

View File

@@ -1,11 +1,15 @@
import torch
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend for HPU.

    HPU always uses HPU_ATTN; the caller's selection is ignored.
    """
    return _Backend.HPU_ATTN
@staticmethod
def inference_mode():
    """Return the inference context manager for HPU.

    NOTE(review): returns ``torch.no_grad()`` rather than
    ``torch.inference_mode()`` — presumably inference_mode is not
    supported on HPU; confirm before changing.
    """
    return torch.no_grad()

View File

@@ -11,6 +11,20 @@ else:
VllmConfig = None
class _Backend(enum.Enum):
    """Enumeration of the available attention backend implementations.

    Platform classes map a user-selected member to the one(s) they
    actually support via ``get_default_attn_backend``.
    """
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()  # default on ROCm platforms
    TORCH_SDPA = enum.auto()  # only backend supported on CPU
    OPENVINO = enum.auto()  # only backend supported on OpenVINO
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()  # only backend supported on HPU
    PALLAS = enum.auto()  # only backend supported on TPU
    IPEX = enum.auto()  # only backend supported on XPU
    NO_ATTENTION = enum.auto()
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
@@ -71,6 +85,11 @@ class Platform:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
    """Get the default attention backend of a device.

    Base implementation returns ``None`` (no platform-specific default);
    platform subclasses override this to validate/remap the selected
    backend to a member of ``_Backend`` they support.
    """
    return None
@classmethod
def get_device_capability(
cls,

View File

@@ -3,7 +3,7 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
@@ -11,6 +11,12 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Resolve the attention backend for OpenVINO.

    OpenVINO only supports its own backend; any other selection is
    logged and overridden.
    """
    if selected_backend is not _Backend.OPENVINO:
        logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
    return _Backend.OPENVINO
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the canonical device name for this platform ("openvino").

    ``device_id`` is accepted for interface parity and ignored.

    Fix: first parameter renamed from ``self`` to ``cls`` — the method is
    decorated with ``@classmethod``, so the bound first argument is the
    class, not an instance. Callers are unaffected (the first argument is
    implicit).
    """
    return "openvino"

View File

@@ -5,7 +5,7 @@ import torch
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
@@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Resolve the attention backend for ROCm.

    FLASH_ATTN is remapped to ROCM_FLASH; any other non-ROCm selection is
    logged and overridden. ROCM_FLASH itself requires device capability
    gfx90x+ (Instinct series) and a warning is logged otherwise.
    """
    if selected_backend == _Backend.FLASH_ATTN:
        selected_backend = _Backend.ROCM_FLASH
    if selected_backend != _Backend.ROCM_FLASH:
        logger.info("%s is not supported in AMD GPUs.", selected_backend)
    elif not cls.has_device_capability(90):
        # not Instinct series GPUs.
        logger.info("flash_attn is not supported on NAVI GPUs.")
    return _Backend.ROCM_FLASH
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

View File

@@ -3,17 +3,27 @@ from typing import TYPE_CHECKING
import torch
from .interface import Platform, PlatformEnum
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
logger = init_logger(__name__)
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Resolve the attention backend for TPU.

    TPU only supports PALLAS; any other selection is logged and
    overridden.
    """
    if selected_backend is not _Backend.PALLAS:
        logger.info("Cannot use %s backend on TPU.", selected_backend)
    return _Backend.PALLAS
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Device-name lookup is not implemented for TPU; always raises."""
    raise NotImplementedError

View File

@@ -1,11 +1,21 @@
import torch
from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Resolve the attention backend for XPU.

    XPU only supports IPEX; any other selection is logged and overridden.
    """
    if selected_backend is not _Backend.IPEX:
        logger.info("Cannot use %s backend on XPU.", selected_backend)
    return _Backend.IPEX
@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
major, minor, *_ = torch.xpu.get_device_capability(