[Core] Allow specifying custom Executor (#6557)

This commit is contained in:
Antoni Baum
2024-07-19 18:25:06 -07:00
committed by GitHub
parent 2e26564259
commit 7bd82002ae
22 changed files with 310 additions and 92 deletions

View File

@@ -35,6 +35,8 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def __init__(
self,
model_config: ModelConfig,
@@ -107,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
return num_gpu_blocks, num_cpu_blocks
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
    """Return the kwargs used to construct every RayWorkerWrapper.

    Centralizes the worker module/class selection so the remote Ray
    actors and the local driver worker (both built via
    ``**self._get_worker_wrapper_args()`` in ``_init_workers_ray``)
    are guaranteed to instantiate the same XPU worker class.
    """
    return dict(
        # Dotted path + class name resolved lazily inside RayWorkerWrapper,
        # avoiding an eager XPU import in the driver process.
        worker_module_name="vllm.worker.xpu_worker",
        worker_class_name="XPUWorker",
        # Propagate the model's trust_remote_code setting to each worker
        # so remote-code models load identically on every rank.
        trust_remote_code=self.model_config.trust_remote_code,
    )
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
@@ -124,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
@@ -137,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
@@ -337,7 +339,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.distributed_executor_backend == "ray"
assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.