[KVConnector] Aggregate finished requests on the scheduler (#19555)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -9,7 +9,8 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
|
||||
if self.max_concurrent_batches > 1:
|
||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||
# from the response queue
|
||||
# _async_aggregate_workers_output also assumes a single IO thread
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io")
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
# Complete transfer tracker. Used by to track finished requests
|
||||
# [req_id -> n_finished_workers]
|
||||
self._recv_remaining_count = defaultdict[str,
|
||||
int](lambda: self.world_size)
|
||||
self._send_remaining_count = defaultdict[str,
|
||||
int](lambda: self.world_size)
|
||||
|
||||
def start_worker_monitor(self):
|
||||
workers = self.workers
|
||||
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
|
||||
def execute_model(
    self,
    scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
    """Run one model step across the workers.

    Without a KV connector, only the worker at ``self.output_rank``
    replies and its output is returned directly.  With a connector,
    every worker replies and the per-worker outputs are aggregated so
    that KV-transfer completion is reported to the scheduler only once
    all workers agree (see ``_aggregate_workers_output``).

    Returns the ``ModelRunnerOutput`` directly, or a ``Future`` of it
    when ``self.max_concurrent_batches > 1`` (non-blocking mode).
    """
    non_block = self.max_concurrent_batches > 1

    if not self.has_connector:
        # get output only from a single worker (output_rank)
        (output, ) = self.collective_rpc(
            "execute_model",
            args=(scheduler_output, ),
            unique_reply_rank=self.output_rank,
            non_block=non_block,
            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
        return output

    # get output from all workers
    outputs = self.collective_rpc(
        "execute_model",
        args=(scheduler_output, ),
        non_block=non_block,
        timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)

    # aggregate all workers output to a single output
    if non_block:
        return self._async_aggregate_workers_output(outputs)
    return self._aggregate_workers_output(outputs)
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
|
||||
def _aggregate_workers_output(
|
||||
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
|
||||
# aggregate finished_sending, finished_recving from all workers
|
||||
|
||||
finished_sending = set[str]()
|
||||
finished_recving = set[str]()
|
||||
for output in outputs:
|
||||
# update finished_sending
|
||||
for req_id in output.finished_sending or []:
|
||||
new_count = self._send_remaining_count[req_id] - 1
|
||||
if new_count == 0:
|
||||
# got response from all workers, report back to scheduler
|
||||
finished_sending.add(req_id)
|
||||
del self._send_remaining_count[req_id]
|
||||
else:
|
||||
self._send_remaining_count[req_id] = new_count
|
||||
|
||||
# update finished_recving
|
||||
for req_id in output.finished_recving or []:
|
||||
new_count = self._recv_remaining_count[req_id] - 1
|
||||
if new_count == 0:
|
||||
# got response from all workers, report back to scheduler
|
||||
finished_recving.add(req_id)
|
||||
del self._recv_remaining_count[req_id]
|
||||
else:
|
||||
self._recv_remaining_count[req_id] = new_count
|
||||
|
||||
# select output of the worker specified by output_rank
|
||||
output = outputs[self.output_rank]
|
||||
|
||||
# set the aggregated finished_sending / finished_recving
|
||||
if finished_sending:
|
||||
output.finished_sending = finished_sending
|
||||
if finished_recving:
|
||||
output.finished_recving = finished_recving
|
||||
|
||||
return output
|
||||
|
||||
def _async_aggregate_workers_output(
|
||||
self, output_futures: list[Future[ModelRunnerOutput]]
|
||||
) -> (Future[ModelRunnerOutput]):
|
||||
"""Takes a list of futures and returns a single future which resolves
|
||||
to the respective list of outputs."""
|
||||
result_future: Future[ModelRunnerOutput] = Future()
|
||||
|
||||
outputs: list[Optional[ModelRunnerOutput]] = [None
|
||||
] * len(output_futures)
|
||||
|
||||
def make_callback(idx):
|
||||
|
||||
def callback(fut):
|
||||
if result_future.done():
|
||||
return
|
||||
|
||||
try:
|
||||
outputs[idx] = fut.result()
|
||||
except CancelledError:
|
||||
result_future.cancel()
|
||||
except Exception as e:
|
||||
result_future.set_exception(e)
|
||||
|
||||
# this check assumes io_thread_pool uses a single thread
|
||||
if all(outputs):
|
||||
result_future.set_result(
|
||||
self._aggregate_workers_output(
|
||||
cast(list[ModelRunnerOutput], outputs)))
|
||||
|
||||
return callback
|
||||
|
||||
for i, output_future in enumerate(output_futures):
|
||||
output_future.add_done_callback(make_callback(i))
|
||||
|
||||
return result_future
|
||||
|
||||
@staticmethod
|
||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||
"""Ensure that all worker processes are terminated. Assumes workers have
|
||||
|
||||
Reference in New Issue
Block a user