[Core] Introduce SPMD worker execution using Ray accelerated DAG (#6032)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
This commit is contained in:
Rui Qiao
2024-07-17 22:27:09 -07:00
committed by GitHub
parent d25877dd9b
commit 61e592747c
8 changed files with 216 additions and 119 deletions

View File

@@ -1,11 +1,11 @@
import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union)
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
@@ -30,7 +30,7 @@ logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor):
@@ -72,10 +72,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
@@ -270,7 +269,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
@@ -293,26 +291,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
@@ -324,36 +316,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self):
def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
from packaging import version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray
assert self.parallel_config.distributed_executor_backend == "ray"
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
@@ -363,7 +347,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
return forward_dag.experimental_compile()
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""