[torch.compile] support all attention backends (#10558)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -20,6 +20,7 @@ logger = init_logger(__name__)
|
||||
class CpuPlatform(Platform):
|
||||
_enum = PlatformEnum.CPU
|
||||
device_type: str = "cpu"
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
|
||||
@@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
|
||||
class CudaPlatform(Platform):
|
||||
_enum = PlatformEnum.CUDA
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||
|
||||
@@ -13,6 +13,7 @@ else:
|
||||
class HpuPlatform(Platform):
|
||||
_enum = PlatformEnum.HPU
|
||||
device_type: str = "hpu"
|
||||
dispatch_key: str = "HPU"
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
|
||||
@@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
|
||||
class Platform:
|
||||
_enum: PlatformEnum
|
||||
device_type: str
|
||||
# available dispatch keys:
|
||||
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
||||
# use "CPU" as a fallback for platforms not registered in PyTorch
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
def is_cuda(self) -> bool:
|
||||
return self._enum == PlatformEnum.CUDA
|
||||
|
||||
@@ -18,6 +18,7 @@ logger = init_logger(__name__)
|
||||
class OpenVinoPlatform(Platform):
|
||||
_enum = PlatformEnum.OPENVINO
|
||||
device_type: str = "openvino"
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
|
||||
@@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
|
||||
class RocmPlatform(Platform):
|
||||
_enum = PlatformEnum.ROCM
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
|
||||
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
|
||||
class TpuPlatform(Platform):
|
||||
_enum = PlatformEnum.TPU
|
||||
device_type: str = "tpu"
|
||||
dispatch_key: str = "XLA"
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
|
||||
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
|
||||
class XPUPlatform(Platform):
|
||||
_enum = PlatformEnum.XPU
|
||||
device_type: str = "xpu"
|
||||
dispatch_key: str = "XPU"
|
||||
|
||||
@classmethod
|
||||
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
|
||||
|
||||
Reference in New Issue
Block a user