[KVConnector] Aggregate finished requests on the scheduler (#19555)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Author: Or Ozeri
Date: 2025-07-10 11:22:18 +03:00
Committed by: GitHub
Parent: fdfd409f8f
Commit: cc876d0f29
5 changed files with 139 additions and 110 deletions


@@ -9,7 +9,8 @@ import threading
 import time
 import traceback
 import weakref
-from concurrent.futures import Future, ThreadPoolExecutor
+from collections import defaultdict
+from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
         if self.max_concurrent_batches > 1:
             # Note: must use only 1 IO thread to keep dequeue sequence
             # from the response queue
+            # _async_aggregate_workers_output also assumes a single IO thread
             self.io_thread_pool = ThreadPoolExecutor(
                 max_workers=1, thread_name_prefix="mp_exec_io")
 
         self.output_rank = self._get_output_rank()
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+
+        # Complete transfer tracker. Used to track finished requests
+        # [req_id -> n_finished_workers]
+        self._recv_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
+        self._send_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
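A minimal standalone sketch of how these per-request counters behave (the world_size of 2 and the request id are made up for illustration): the first access of a request id initializes its entry to world_size, and the entry is dropped once every worker has reported the request.

from collections import defaultdict

world_size = 2  # hypothetical tensor-parallel size, for illustration only
send_remaining = defaultdict[str, int](lambda: world_size)

def record_finished_send(req_id: str) -> bool:
    """Count one worker's report; True once all workers reported req_id."""
    send_remaining[req_id] -= 1  # first access initializes to world_size
    if send_remaining[req_id] == 0:
        del send_remaining[req_id]  # free the entry once fully reported
        return True
    return False

assert record_finished_send("req-a") is False  # worker 0 reported, 1 left
assert record_finished_send("req-a") is True   # worker 1 reported, done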
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
         self,
         scheduler_output,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        (output, ) = self.collective_rpc(
+        non_block = self.max_concurrent_batches > 1
+
+        if not self.has_connector:
+            # get output only from a single worker (output_rank)
+            (output, ) = self.collective_rpc(
+                "execute_model",
+                args=(scheduler_output, ),
+                unique_reply_rank=self.output_rank,
+                non_block=non_block,
+                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
+            return output
+
+        # get output from all workers
+        outputs = self.collective_rpc(
             "execute_model",
             args=(scheduler_output, ),
-            unique_reply_rank=self.output_rank,
-            non_block=self.max_concurrent_batches > 1,
+            non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-        return output
+
+        # aggregate all workers' outputs into a single output
+        if non_block:
+            return self._async_aggregate_workers_output(outputs)
+        return self._aggregate_workers_output(outputs)
 
     def collective_rpc(self,
                        method: Union[str, Callable],
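As a toy model of the two reply modes used above (not vLLM's collective_rpc; the rank strings are invented): with unique_reply_rank the call returns a single-element list that can be unpacked as (output, ), while without it every worker's output comes back and must be aggregated.

from typing import Optional

def collective_rpc_toy(world_size: int,
                       unique_reply_rank: Optional[int] = None) -> list[str]:
    """Stand-in for collective_rpc's reply handling only."""
    replies = [f"output from rank {r}" for r in range(world_size)]
    if unique_reply_rank is not None:
        return [replies[unique_reply_rank]]  # reply from one rank only
    return replies  # one reply per worker, to be aggregated

(output, ) = collective_rpc_toy(world_size=2, unique_reply_rank=0)
outputs = collective_rpc_toy(world_size=2)
print(output)        # output from rank 0
print(len(outputs))  # 2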
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e
 
+    def _aggregate_workers_output(
+            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            # update finished_sending
+            for req_id in output.finished_sending or []:
+                new_count = self._send_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_sending.add(req_id)
+                    del self._send_remaining_count[req_id]
+                else:
+                    self._send_remaining_count[req_id] = new_count
+
+            # update finished_recving
+            for req_id in output.finished_recving or []:
+                new_count = self._recv_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_recving.add(req_id)
+                    del self._recv_remaining_count[req_id]
+                else:
+                    self._recv_remaining_count[req_id] = new_count
+
+        # select output of the worker specified by output_rank
+        output = outputs[self.output_rank]
+
+        # set the aggregated finished_sending / finished_recving
+        if finished_sending:
+            output.finished_sending = finished_sending
+        if finished_recving:
+            output.finished_recving = finished_recving
+
+        return output
+
+    def _async_aggregate_workers_output(
+        self, output_futures: list[Future[ModelRunnerOutput]]
+    ) -> (Future[ModelRunnerOutput]):
+        """Takes a list of futures and returns a single future which resolves
+        to the aggregated output once all workers have responded."""
+        result_future: Future[ModelRunnerOutput] = Future()
+
+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self._aggregate_workers_output(
+                            cast(list[ModelRunnerOutput], outputs)))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
+
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have