[platforms] absorb worker cls difference into platforms folder (#10555)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
youkaichao
2024-11-21 21:00:32 -08:00
committed by GitHub
parent 446c7806b2
commit a111d0151f
21 changed files with 272 additions and 282 deletions

vllm/platforms/hpu.py

@@ -1,7 +1,14 @@
+from typing import TYPE_CHECKING
+
 import torch
 
 from .interface import Platform, PlatformEnum, _Backend
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
 
 class HpuPlatform(Platform):
     _enum = PlatformEnum.HPU
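
The if TYPE_CHECKING: / else: VllmConfig = None guard added in this hunk is the standard idiom for a type-only import: static checkers see the real VllmConfig class, while at runtime the unquoted annotation on check_and_update_config (next hunk) simply evaluates to None, which Python accepts as an annotation object, so vllm.config is never imported at module load time (typically done to break an import cycle or keep imports cheap). A minimal standalone sketch of the idiom, using a hypothetical heavy_module:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by mypy/pyright; this import never executes.
    from heavy_module import HeavyConfig
else:
    # Runtime placeholder: the annotation below evaluates to None,
    # a legal annotation object, so heavy_module is never imported.
    HeavyConfig = None


def configure(cfg: HeavyConfig) -> None:
    # cfg is duck-typed at runtime; type checkers still see HeavyConfig.
    print(cfg)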
@@ -14,3 +21,19 @@ class HpuPlatform(Platform):
     @staticmethod
     def inference_mode():
         return torch.no_grad()
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        scheduler_config = vllm_config.scheduler_config
+        if scheduler_config.is_multi_step:
+            raise NotImplementedError(
+                "Multi-step execution is not implemented for HPU")
+
+        if vllm_config.speculative_config is not None:
+            raise NotImplementedError(
+                "Speculative decoding is not implemented for HPU")
+
+        parallel_config = vllm_config.parallel_config
+        if parallel_config.worker_cls == "auto":
+            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
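
Downstream, a string-valued worker_cls such as "vllm.worker.hpu_worker.HPUWorker" is resolved back into a class by importing it from its dotted path; each platform picks its own worker inside vllm/platforms, so engine and executor code no longer needs per-platform branching. A hedged sketch of that resolution step (illustrative only, not the actual vLLM executor code; resolve_obj_by_qualname here is a local helper):

import importlib


def resolve_obj_by_qualname(qualname: str):
    """Import an object from a dotted path, e.g.
    'vllm.worker.hpu_worker.HPUWorker'."""
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Illustrative flow (names assumed): the platform hook normalizes the
# config, turning worker_cls == "auto" into a concrete dotted path, and
# the executor then instantiates whatever class the platform selected:
#
#   current_platform.check_and_update_config(vllm_config)
#   worker_cls = resolve_obj_by_qualname(
#       vllm_config.parallel_config.worker_cls)
#   worker = worker_cls(...)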