[Core] Allow specifying custom Executor (#6557)
This commit is contained in:
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
|
||||
|
||||
class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
@@ -47,7 +49,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1")
|
||||
|
||||
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||
assert self.uses_ray
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
# Disable Ray usage stats collection.
|
||||
@@ -75,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
return ray_remote_kwargs
|
||||
|
||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||
if self.speculative_config is not None:
|
||||
worker_module_name = "vllm.spec_decode.spec_decode_worker"
|
||||
worker_class_name = "create_spec_worker"
|
||||
else:
|
||||
worker_module_name = "vllm.worker.worker"
|
||||
worker_class_name = "Worker"
|
||||
|
||||
return dict(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if (self.parallel_config.tensor_parallel_size == 1
|
||||
@@ -97,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
@@ -106,23 +123,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
if self.speculative_config is not None:
|
||||
worker_module_name = "vllm.spec_decode.spec_decode_worker"
|
||||
worker_class_name = "create_spec_worker"
|
||||
else:
|
||||
worker_module_name = "vllm.worker.worker"
|
||||
worker_class_name = "Worker"
|
||||
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
@@ -133,10 +139,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
**worker_wrapper_kwargs)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
@@ -378,7 +381,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
f"required, but found {current_version}")
|
||||
|
||||
from ray.dag import InputNode, MultiOutputNode
|
||||
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||
assert self.parallel_config.use_ray
|
||||
|
||||
# Right now, compiled DAG requires at least 1 arg. We send
|
||||
# a dummy value for now. It will be fixed soon.
|
||||
|
||||
Reference in New Issue
Block a user