[Core] Eliminate parallel worker per-step task scheduling overhead (#4894)

This commit is contained in:
Nick Hill
2024-05-22 14:17:27 -07:00
committed by GitHub
parent 97b030005c
commit eb6d3c264d
12 changed files with 350 additions and 211 deletions

View File

@@ -42,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
@@ -171,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={"execute_model_req": execute_model_req},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
# Only the driver worker returns the sampling results.
return all_outputs[0]
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
async_run_remote_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
@@ -198,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
@@ -209,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
@@ -225,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
ray_worker_outputs = []
else:
# Start the ray workers first.
ray_worker_outputs = [
@@ -234,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only:
# Just return futures
return ray_worker_outputs
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
@@ -260,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
return [driver_worker_output] + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self):
import pkg_resources
required_version = "2.9"
@@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
self.driver_exec_method = make_async(self.driver_worker.execute_method)
async def _run_workers_async(
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req)
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
coros.append(
self.driver_executor(method, *driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)