[Core] Eliminate parallel worker per-step task scheduling overhead (#4894)
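Before this change, every model step dispatched a fresh Ray task to each worker, paying task-scheduling overhead on every step. With this change, the remote workers are launched once into a long-running execution loop, and each step only makes a direct call into the driver worker. A minimal sketch of the resulting control flow, assuming the per-step wrapper lives in the shared `DistributedGPUExecutor` base class; only `_driver_execute_model`, `_run_workers`, `_wait_for_tasks_completion`, and `extra_execute_model_run_workers_kwargs` appear in the diff below, the rest is illustrative:

```python
class DistributedGPUExecutorSketch:
    """Illustrative sketch, not the literal base class from this commit."""

    def __init__(self):
        self.parallel_worker_tasks = None  # Ray futures for the worker loops
        self.extra_execute_model_run_workers_kwargs = {}

    def execute_model(self, execute_model_req):
        if self.parallel_worker_tasks is None:
            # One-time dispatch: each remote worker enters a loop that keeps
            # executing steps, so no per-step task scheduling is needed.
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)
        # Per step, only the driver worker is invoked directly.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self):
        if self.parallel_worker_tasks is None:
            return
        # A None request tells the remote loops to exit (see the
        # _driver_execute_model docstring below), then we drain the futures.
        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        self._wait_for_tasks_completion(parallel_worker_tasks)
```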
@@ -42,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
+            self.extra_execute_model_run_workers_kwargs[
+                "use_ray_compiled_dag"] = True
 
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
@@ -171,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers)
 
-    def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        all_outputs = self._run_workers(
-            "execute_model",
-            driver_kwargs={"execute_model_req": execute_model_req},
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_method("execute_model",
+                                                 execute_model_req)
 
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
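The None sentinel described in the docstring above pairs with a worker-side loop. A hedged sketch of what `start_worker_execution_loop` plausibly looks like on each worker (the worker change lives in a different file of this PR; `execute_model` here is the worker's own method, which blocks on broadcast input from the driver):

```python
def start_worker_execution_loop(self) -> None:
    """Run execution steps until the driver broadcasts a stop signal."""
    while True:
        # Non-driver workers block here waiting for the driver's broadcast;
        # a None result means the driver sent the stop (None) request.
        output = self.execute_model(execute_model_req=None)
        if output is None:
            return
```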
@@ -198,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
         """Runs the given method on all workers. Can be used in the following
         ways:
 
+        - async_run_remote_workers_only: If True the method will be run only
+          in the remote workers, not the driver worker. It will also be
+          run asynchronously and return a list of futures rather than blocking
+          on the results.
         - args/kwargs: All workers share the same args/kwargs
-        - args/kwargs and driver_args/driver_kwargs: Driver worker has
-          different args
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
         """
@@ -209,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        if driver_args is None:
-            driver_args = args if all_args is None else all_args[0]
-        if driver_kwargs is None:
-            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
-
         count = len(self.workers)
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, 1, None)
@@ -225,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             # input. TODO(sang): Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
+            ray_worker_outputs = []
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
@@ -234,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
                      ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]
 
+        if async_run_remote_workers_only:
+            # Just return futures
+            return ray_worker_outputs
+
+        driver_args = args if all_args is None else all_args[0]
+        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
@@ -260,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        ray.get(parallel_worker_tasks)
+
     def _compiled_ray_dag(self):
         import pkg_resources
         required_version = "2.9"
@@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.driver_executor = make_async(self.driver_worker.execute_method)
+        self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        coros.append(
-            self.driver_executor(method, *driver_args, **driver_kwargs))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_method("execute_model",
+                                             execute_model_req)
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method.remote("start_worker_execution_loop")
+            for worker in self.workers
+        ]
+        return await asyncio.gather(*coros)
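Two notes on the async variant: `asyncio.gather` can await the `ObjectRef`s returned by `.remote()` directly because Ray object refs are awaitable, and `make_async` is needed because the driver worker's `execute_method` is a blocking call. A rough approximation of what `make_async` does (the real helper lives in vLLM's utils; this is a sketch, not the exact implementation):

```python
import asyncio
from functools import partial
from typing import Any, Callable

def make_async(func: Callable) -> Callable:
    """Wrap a blocking callable so it can be awaited without stalling the
    event loop, by running it in the default thread-pool executor."""
    async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))
    return _async_wrapper
```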