[platforms] absorb worker cls difference into platforms folder (#10555)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -115,13 +115,8 @@ class CPUExecutor(ExecutorBase):
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
):
|
||||
worker_module_name = "vllm.worker.cpu_worker"
|
||||
worker_class_name = "CPUWorker"
|
||||
|
||||
wrapper = WorkerWrapperBase(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
)
|
||||
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
|
||||
|
||||
assert self.distributed_init_method is not None
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
@@ -8,19 +8,14 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_worker(worker_module_name: str, worker_class_name: str,
|
||||
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
|
||||
**kwargs):
|
||||
wrapper = WorkerWrapperBase(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
)
|
||||
def create_worker(**kwargs):
|
||||
vllm_config = kwargs.get("vllm_config")
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
|
||||
wrapper.init_worker(**kwargs)
|
||||
return wrapper.worker
|
||||
|
||||
@@ -57,43 +52,11 @@ class GPUExecutor(ExecutorBase):
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
|
||||
def _get_worker_module_and_class(
|
||||
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
|
||||
worker_class_fn = None
|
||||
if self.scheduler_config.is_multi_step:
|
||||
worker_module_name = "vllm.worker.multi_step_worker"
|
||||
worker_class_name = "MultiStepWorker"
|
||||
elif self.speculative_config:
|
||||
worker_module_name = "vllm.spec_decode.spec_decode_worker"
|
||||
worker_class_name = "create_spec_worker"
|
||||
else:
|
||||
worker_module_name = "vllm.worker.worker"
|
||||
worker_class_name = "Worker"
|
||||
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||
|
||||
def _get_create_worker_kwargs(
|
||||
self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None) -> Dict:
|
||||
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
|
||||
distributed_init_method)
|
||||
|
||||
(worker_module_name, worker_class_name,
|
||||
worker_class_fn) = self._get_worker_module_and_class()
|
||||
worker_kwargs.update(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
)
|
||||
|
||||
return worker_kwargs
|
||||
|
||||
def _create_worker(self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None):
|
||||
return create_worker(**self._get_create_worker_kwargs(
|
||||
return create_worker(**self._get_worker_kwargs(
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method))
|
||||
|
||||
@@ -48,10 +48,7 @@ class HPUExecutor(ExecutorBase):
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None):
|
||||
wrapper = WorkerWrapperBase(
|
||||
worker_module_name="vllm.worker.hpu_worker",
|
||||
worker_class_name="HPUWorker",
|
||||
)
|
||||
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
|
||||
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
|
||||
distributed_init_method))
|
||||
return wrapper.worker
|
||||
|
||||
@@ -90,7 +90,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
result_handler,
|
||||
partial(
|
||||
create_worker,
|
||||
**self._get_create_worker_kwargs(
|
||||
**self._get_worker_kwargs(
|
||||
rank=rank,
|
||||
local_rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
|
||||
@@ -7,6 +7,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -25,10 +26,10 @@ class NeuronExecutor(ExecutorBase):
|
||||
self._init_worker()
|
||||
|
||||
def _init_worker(self):
|
||||
from vllm.worker.neuron_worker import NeuronWorker
|
||||
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = NeuronWorker(
|
||||
self.driver_worker = wrapper.init_worker(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
|
||||
get_open_port, make_async)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -38,15 +39,12 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
self._init_worker()
|
||||
|
||||
def _init_worker(self):
|
||||
from vllm.worker.openvino_worker import OpenVINOWorker
|
||||
|
||||
assert (
|
||||
self.parallel_config.world_size == 1
|
||||
), "OpenVINOExecutor only supports single CPU socket currently."
|
||||
wrapper = WorkerWrapperBase(vllm_config=self.vllm_config)
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = OpenVINOWorker(
|
||||
self.driver_worker = wrapper.init_worker(
|
||||
ov_core=self.ov_core,
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=0,
|
||||
|
||||
@@ -91,17 +91,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
return ray_remote_kwargs
|
||||
|
||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||
(worker_module_name, worker_class_name,
|
||||
worker_class_fn) = self._get_worker_module_and_class()
|
||||
|
||||
return dict(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
# child class could overwrite this to return actual env vars.
|
||||
def _get_env_vars_to_be_updated(self):
|
||||
return self._env_vars_for_all_workers
|
||||
@@ -135,7 +124,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
@@ -150,7 +138,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
@@ -161,7 +149,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
**worker_wrapper_kwargs)
|
||||
vllm_config=self.vllm_config)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
@@ -2,8 +2,7 @@ import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Type)
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import msgspec
|
||||
|
||||
@@ -18,7 +17,6 @@ from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
|
||||
get_ip, get_open_port, get_vllm_instance_id,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
@@ -81,33 +79,6 @@ class RayHPUExecutor(DistributedGPUExecutor):
|
||||
def finish_measurements(self):
|
||||
self._run_workers("finish_measurements")
|
||||
|
||||
def _get_worker_module_and_class(
|
||||
self
|
||||
) -> Tuple[str, str, Optional[Callable[[],
|
||||
Type[WorkerBase]]]]: # noqa: F821
|
||||
worker_class_fn = None
|
||||
if self.scheduler_config.is_multi_step:
|
||||
raise NotImplementedError(
|
||||
"Multi-step execution is not implemented for HPU")
|
||||
elif self.speculative_config:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not implemented for HPU")
|
||||
else:
|
||||
worker_module_name = "vllm.worker.hpu_worker"
|
||||
worker_class_name = "HPUWorker"
|
||||
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||
|
||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||
(worker_module_name, worker_class_name,
|
||||
worker_class_fn) = self._get_worker_module_and_class()
|
||||
|
||||
return dict(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
worker_class_fn=worker_class_fn,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
# Otherwise, the ray workers are allocated with a full GPU.
|
||||
@@ -128,7 +99,6 @@ class RayHPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("HPU", 0):
|
||||
continue
|
||||
@@ -144,7 +114,7 @@ class RayHPUExecutor(DistributedGPUExecutor):
|
||||
resources={'HPU': num_gpus},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
@@ -155,7 +125,7 @@ class RayHPUExecutor(DistributedGPUExecutor):
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
**worker_wrapper_kwargs)
|
||||
vllm_config=self.vllm_config)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
@@ -69,14 +69,6 @@ class RayTPUExecutor(TPUExecutor):
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
assert self.speculative_config is None
|
||||
if self.scheduler_config.is_multi_step:
|
||||
worker_module_name = "vllm.worker.multi_step_tpu_worker"
|
||||
worker_class_name = "MultiStepTPUWorker"
|
||||
else:
|
||||
worker_module_name = "vllm.worker.tpu_worker"
|
||||
worker_class_name = "TPUWorker"
|
||||
|
||||
# GKE does not fetch environment information from metadata server
|
||||
# and instead sets these from within the Ray process. Therefore we
|
||||
# need to override the Ray environment variables manually.
|
||||
@@ -95,11 +87,7 @@ class RayTPUExecutor(TPUExecutor):
|
||||
resources={"TPU": 1},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
|
||||
if override_env:
|
||||
worker.override_env_vars.remote(override_env)
|
||||
|
||||
@@ -109,10 +97,7 @@ class RayTPUExecutor(TPUExecutor):
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
vllm_config=self.vllm_config)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, List, Optional, Tuple, Type, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
@@ -6,7 +6,6 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import make_async
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -22,17 +21,6 @@ class XPUExecutor(GPUExecutor):
|
||||
|
||||
GPUExecutor._init_executor(self)
|
||||
|
||||
def _get_worker_module_and_class(
|
||||
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
|
||||
worker_class_fn = None
|
||||
if self.speculative_config is not None:
|
||||
raise NotImplementedError(
|
||||
"XPU does not support speculative decoding")
|
||||
else:
|
||||
worker_module_name = "vllm.worker.xpu_worker"
|
||||
worker_class_name = "XPUWorker"
|
||||
return (worker_module_name, worker_class_name, worker_class_fn)
|
||||
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||
|
||||
Reference in New Issue
Block a user