[Core] Allow specifying custom Executor (#6557)
This commit is contained in:
@@ -35,6 +35,8 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
|
||||
class RayXPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
@@ -107,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||
return dict(
|
||||
worker_module_name="vllm.worker.xpu_worker",
|
||||
worker_class_name="XPUWorker",
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
@@ -124,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
@@ -137,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(
|
||||
worker_module_name="vllm.worker.xpu_worker",
|
||||
worker_class_name="XPUWorker",
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
worker_module_name="vllm.worker.xpu_worker",
|
||||
worker_class_name="XPUWorker",
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
@@ -337,7 +339,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
f"required, but found {current_version}")
|
||||
|
||||
from ray.dag import InputNode, MultiOutputNode
|
||||
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||
assert self.parallel_config.use_ray
|
||||
|
||||
# Right now, compiled DAG requires at least 1 arg. We send
|
||||
# a dummy value for now. It will be fixed soon.
|
||||
|
||||
Reference in New Issue
Block a user