[Core] Multiprocessing Pipeline Parallel support (#6130)

Co-authored-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Nick Hill
2024-07-18 19:15:52 -07:00
committed by GitHub
parent c5df56f88b
commit b5672a112c
9 changed files with 152 additions and 99 deletions

View File

@@ -7,12 +7,13 @@ from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (cuda_device_count_stateless,
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
error_on_invalid_device_count_status,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async,
@@ -26,7 +27,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
# Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
@@ -49,8 +51,15 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if world_size > 1:
maybe_set_triton_cache_manager()
assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")
assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
error_on_invalid_device_count_status()
@@ -60,21 +69,35 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
self.workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[ProcessWorkerWrapper] = []
if world_size == 1:
self.workers = []
self.worker_monitor = None
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
for rank in range(1, world_size):
worker = ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)) for rank in range(1, world_size)
]
create_worker,
**self._get_create_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
@@ -136,16 +159,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
for worker in self.non_driver_workers
]
# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return worker_outputs
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
@@ -172,16 +198,45 @@ class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)
if not self.tp_driver_workers:
return await self.driver_exec_model(execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method_async,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)