[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .interface import _Backend # noqa: F401
|
||||
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
|
||||
|
||||
current_platform: Platform
|
||||
|
||||
@@ -5,7 +5,9 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
@@ -22,6 +24,12 @@ class CpuPlatform(Platform):
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return "cpu"
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on CPU.

    The result is always ``_Backend.TORCH_SDPA``. When the caller asked
    for anything else, an info message is logged so the override is
    visible, but the request is otherwise ignored.
    """
    cpu_backend = _Backend.TORCH_SDPA
    if selected_backend != cpu_backend:
        logger.info("Cannot use %s backend on CPU.", selected_backend)
    return cpu_backend
|
||||
|
||||
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total`` field of ``psutil.virtual_memory()``.

    ``device_id`` is accepted for interface compatibility but is not
    consulted by this implementation.
    """
    memory_stats = psutil.virtual_memory()
    return memory_stats.total
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
|
||||
class HpuPlatform(Platform):
|
||||
_enum = PlatformEnum.HPU
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on HPU.

    ``selected_backend`` is ignored: the result is unconditionally
    ``_Backend.HPU_ATTN`` and nothing is logged.
    """
    return _Backend.HPU_ATTN
|
||||
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
return torch.no_grad()
|
||||
|
||||
@@ -11,6 +11,20 @@ else:
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
class _Backend(enum.Enum):
    """Identifiers for the available attention backends.

    Member values are produced by ``enum.auto()`` in declaration order,
    so inserting or reordering members changes their integer values.
    Compare members by identity or name rather than by raw value.
    """

    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
|
||||
CUDA = enum.auto()
|
||||
ROCM = enum.auto()
|
||||
@@ -71,6 +85,11 @@ class Platform:
|
||||
"""Stateless version of :func:`torch.cuda.is_available`."""
|
||||
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
    """Get the default attention backend of a device.

    This base implementation expresses no preference: it ignores
    ``selected_backend`` and returns ``None``. Platform subclasses
    override it to return the ``_Backend`` member they support.
    """
    return None
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(
|
||||
cls,
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -11,6 +11,12 @@ logger = init_logger(__name__)
|
||||
class OpenVinoPlatform(Platform):
|
||||
_enum = PlatformEnum.OPENVINO
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on OpenVINO.

    The result is always ``_Backend.OPENVINO``. A request for any other
    backend is reported with an info-level log message and then ignored.
    """
    openvino_backend = _Backend.OPENVINO
    if selected_backend != openvino_backend:
        logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
    return openvino_backend
|
||||
|
||||
@classmethod
|
||||
def get_device_name(self, device_id: int = 0) -> str:
|
||||
return "openvino"
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
|
||||
class RocmPlatform(Platform):
|
||||
_enum = PlatformEnum.ROCM
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on ROCm.

    The result is always ``_Backend.ROCM_FLASH``; the argument only
    determines which informational message, if any, is logged:

    * ``FLASH_ATTN`` is treated as a request for ``ROCM_FLASH``.
    * For ``ROCM_FLASH``, a message is logged when
      ``cls.has_device_capability(90)`` is false.
    * Any other value is reported as unsupported.
    """
    # A FLASH_ATTN request is handled exactly like a ROCM_FLASH request.
    if selected_backend == _Backend.FLASH_ATTN:
        selected_backend = _Backend.ROCM_FLASH

    if selected_backend != _Backend.ROCM_FLASH:
        logger.info("%s is not supported in AMD GPUs.", selected_backend)
    elif not cls.has_device_capability(90):
        # not Instinct series GPUs.
        logger.info("flash_attn is not supported on NAVI GPUs.")

    return _Backend.ROCM_FLASH
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||
|
||||
@@ -3,17 +3,27 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TpuPlatform(Platform):
|
||||
_enum = PlatformEnum.TPU
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on TPU.

    The result is always ``_Backend.PALLAS``. A request for any other
    backend is reported with an info-level log message and then ignored.
    """
    tpu_backend = _Backend.PALLAS
    if selected_backend != tpu_backend:
        logger.info("Cannot use %s backend on TPU.", selected_backend)
    return tpu_backend
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
import torch
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUPlatform(Platform):
|
||||
_enum = PlatformEnum.XPU
|
||||
|
||||
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on XPU.

    The result is always ``_Backend.IPEX``. A request for any other
    backend is reported with an info-level log message and then ignored.
    """
    xpu_backend = _Backend.IPEX
    if selected_backend != xpu_backend:
        logger.info("Cannot use %s backend on XPU.", selected_backend)
    return xpu_backend
|
||||
|
||||
@staticmethod
|
||||
def get_device_capability(device_id: int = 0) -> DeviceCapability:
|
||||
major, minor, *_ = torch.xpu.get_device_capability(
|
||||
|
||||
Reference in New Issue
Block a user