diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index f9b771540..2c99d28fd 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -67,25 +67,29 @@ logger = init_logger(__name__) class FutureWrapper(Future): def __init__( self, - futures_queue: deque[tuple["FutureWrapper", Callable]], + futures_queue: deque["FutureWrapper"], + get_response: Callable[[], Any], aggregate: Callable = lambda x: x, ): self.futures_queue = futures_queue + self.get_response = get_response self.aggregate = aggregate super().__init__() + self.futures_queue.appendleft(self) def result(self, timeout=None): if timeout is not None: raise RuntimeError("timeout not implemented") + # Drain any futures ahead of us in the queue. while not self.done(): - future, get_response = self.futures_queue.pop() - future.wait_for_response(get_response) + future = self.futures_queue.pop() + future._wait_for_response() return super().result() - def wait_for_response(self, get_response: Callable): + def _wait_for_response(self): try: - response = self.aggregate(get_response()) + response = self.aggregate(self.get_response()) with suppress(InvalidStateError): self.set_result(response) except Exception as e: @@ -218,7 +222,7 @@ class MultiprocExecutor(Executor): for response_mq in self.response_mqs: response_mq.wait_until_ready() - self.futures_queue = deque[tuple[FutureWrapper, Callable]]() + self.futures_queue = deque[FutureWrapper]() self._post_init_executor() @@ -384,17 +388,13 @@ class MultiprocExecutor(Executor): responses.append(result) return responses[0] if output_rank is not None else responses - if non_block: - future = FutureWrapper(self.futures_queue, aggregate=aggregate) - self.futures_queue.appendleft((future, get_response)) - return future + future = FutureWrapper( + self.futures_queue, + get_response=get_response, + aggregate=aggregate, + ) - # First drain any pending futures in the queue. - while self.futures_queue: - future, get_fut_response = self.futures_queue.pop() - future.wait_for_response(get_fut_response) - - return aggregate(get_response()) + return future if non_block else future.result() @staticmethod def _ensure_worker_termination(worker_procs: list[BaseProcess]):