[BugFix] Make PD work with Ray (#21072)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
kourosh hakhamaneshi
2025-07-19 08:46:50 -07:00
committed by GitHub
parent 6a971ed692
commit 9f414a12ad
11 changed files with 330 additions and 222 deletions

View File

@@ -9,8 +9,7 @@ import threading
import time
import traceback
import weakref
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@@ -27,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
@@ -118,13 +118,8 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
# Complete transfer tracker. Used by to track finished requests
# [req_id -> n_finished_workers]
self._recv_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self._send_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self.kv_output_aggregator = KVOutputAggregator(
self.parallel_config.world_size)
def start_worker_monitor(self):
workers = self.workers
@@ -186,8 +181,9 @@ class MultiprocExecutor(Executor):
# aggregate all workers output to a single output
if non_block:
return self._async_aggregate_workers_output(outputs)
return self._aggregate_workers_output(outputs)
return self.kv_output_aggregator.async_aggregate(
outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def collective_rpc(self,
method: Union[str, Callable],
@@ -246,74 +242,6 @@ class MultiprocExecutor(Executor):
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
def _aggregate_workers_output(
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
finished_set: set[str]) -> None:
for req_id in req_ids or ():
new_count = remaining_count_dict[req_id] - 1
if new_count == 0:
finished_set.add(req_id)
del remaining_count_dict[req_id]
else:
remaining_count_dict[req_id] = new_count
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)
# select output of the worker specified by output_rank
output = outputs[self.output_rank]
# set the aggregated finished_sending / finished_recving
output.finished_sending = finished_sending if finished_sending else None
output.finished_recving = finished_recving if finished_recving else None
return output
def _async_aggregate_workers_output(
self, output_futures: list[Future[ModelRunnerOutput]]
) -> (Future[ModelRunnerOutput]):
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
outputs: list[Optional[ModelRunnerOutput]] = [None
] * len(output_futures)
def make_callback(idx):
def callback(fut):
if result_future.done():
return
try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self._aggregate_workers_output(
cast(list[ModelRunnerOutput], outputs)))
return callback
for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
return result_future
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
"""Ensure that all worker processes are terminated. Assumes workers have