[V1][PP] Support PP for MultiprocExecutor (#14219)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-05-06 22:58:05 +08:00
committed by GitHub
parent d419aa5dc4
commit a6fed02068
5 changed files with 98 additions and 28 deletions

View File

@@ -8,7 +8,7 @@ import threading
import time
import traceback
import weakref
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor):
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, (
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}). "
f"Pipeline parallelism is not yet implemented in v1")
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). ")
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor):
self._ensure_worker_termination(
[w.proc for w in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")
self.output_rank = self._get_output_rank()
def start_worker_monitor(self):
workers = self.workers
self_ref = weakref.ref(self)
@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor):
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
(output, ) = self.collective_rpc("execute_model",
args=(scheduler_output, ),
rank0_reply_only=True,
unique_reply_rank=self.output_rank,
non_block=self.max_concurrent_batches
> 1,
timeout=EXECUTE_MODEL_TIMEOUT_S)
return output
@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor):
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
rank0_reply_only: bool = False) -> list[Any]:
non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor):
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, rank0_reply_only))
(send_method, args, kwargs, unique_reply_rank))
workers = (self.workers[0], ) if rank0_reply_only else self.workers
responses = [None] * len(workers)
for w in workers:
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
workers = (self.workers[unique_reply_rank],
) if unique_reply_rank is not None else self.workers
responses = []
def get_response(w: WorkerProcHandle,
dequeue_timeout: Optional[float] = None,
cancel_event: Optional[threading.Event] = None):
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=self.shutdown_event)
timeout=dequeue_timeout, cancel=cancel_event)
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause")
return result
responses[w.rank] = result
for w in workers:
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
if non_block:
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event)
else:
result = get_response(w, dequeue_timeout)
responses.append(result)
return responses
except TimeoutError as e:
@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor):
if not getattr(self, 'shutting_down', False):
self.shutting_down = True
self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
self.io_thread_pool = None
for w in self.workers:
w.worker_response_mq = None
self._ensure_worker_termination([w.proc for w in self.workers])
@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor):
self.collective_rpc("check_health", timeout=10)
return
@property
def max_concurrent_batches(self) -> int:
return self.parallel_config.pipeline_parallel_size
def _get_output_rank(self) -> int:
# Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
# (the first TP worker of the last PP stage).
# Example:
# Assuming TP=8, PP=4, then the world_size=32
# 0-7, PP rank 0
# 8-15, PP rank 1
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size
@dataclass
class UnreadyWorkerProcHandle:
@@ -280,12 +329,14 @@ class WorkerProc:
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = (
rank % vllm_config.parallel_config.tensor_parallel_size == 0)
all_kwargs[rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": rank == 0,
"is_driver_worker": is_driver_worker,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper
@@ -455,7 +506,7 @@ class WorkerProc:
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
try:
if isinstance(method, str):
@@ -470,11 +521,11 @@ class WorkerProc:
logger.exception("WorkerProc hit an exception.")
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
if not rank0_only or self.rank == 0:
if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue
if not rank0_only or self.rank == 0:
if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))