[Core] Eliminate parallel worker per-step task scheduling overhead (#4894)

This commit is contained in:
Nick Hill
2024-05-22 14:17:27 -07:00
committed by GitHub
parent 97b030005c
commit eb6d3c264d
12 changed files with 350 additions and 211 deletions

View File

@@ -1,13 +1,14 @@
import asyncio
import os
from functools import partial
from typing import Any, Dict, Optional, Tuple
from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
@@ -71,16 +72,34 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
None)) is not None:
worker_monitor.close()
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_model(
execute_model_req=execute_model_req)
def _run_workers(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
if max_concurrent_workers:
raise NotImplementedError(
@@ -92,15 +111,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
for worker in self.workers
]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
if async_run_remote_workers_only:
# Just return futures
return worker_outputs
# Start the driver worker after all the ray workers.
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*driver_args,
**driver_kwargs)
driver_worker_output = driver_worker_method(*args, **kwargs)
# Get the results of the workers.
return [driver_worker_output
@@ -111,30 +127,29 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if not self.worker_monitor.is_alive():
raise RuntimeError("Worker processes are not running")
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
DistributedGPUExecutorAsync):
async def _run_workers_async(
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
async def _driver_execute_model_async(
self,
method: str,
*args,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)
driver_executor = make_async(getattr(self.driver_worker, method))
# Run all the workers asynchronously.
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
worker.execute_method_async(method, *args, **kwargs)
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
]
return await asyncio.gather(*coros)