[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2024-11-19 11:22:26 +08:00
committed by GitHub
parent 7eb719df13
commit 8c1fb50705
14 changed files with 99 additions and 69 deletions

View File

@@ -1,11 +1,21 @@
import torch
from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger = init_logger(__name__)
class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
major, minor, *_ = torch.xpu.get_device_capability(