[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -1,11 +1,21 @@
|
||||
import torch
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUPlatform(Platform):
|
||||
_enum = PlatformEnum.XPU
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
if selected_backend != _Backend.IPEX:
|
||||
logger.info("Cannot use %s backend on XPU.", selected_backend)
|
||||
return _Backend.IPEX
|
||||
|
||||
@staticmethod
|
||||
def get_device_capability(device_id: int = 0) -> DeviceCapability:
|
||||
major, minor, *_ = torch.xpu.get_device_capability(
|
||||
|
||||
Reference in New Issue
Block a user