[Refactor] Simplify FutureWrapper in MultiprocExecutor (#38644)

Signed-off-by: Yifan <yzong@redhat.com>
Signed-off-by: Yifan Zong <yzong@redhat.com>
This commit is contained in:
yzong-rh
2026-04-01 17:28:26 -04:00
committed by GitHub
parent 6183cae1bd
commit cb268e4e55

View File

@@ -67,25 +67,29 @@ logger = init_logger(__name__)
class FutureWrapper(Future):
def __init__(
self,
futures_queue: deque[tuple["FutureWrapper", Callable]],
futures_queue: deque["FutureWrapper"],
get_response: Callable[[], Any],
aggregate: Callable = lambda x: x,
):
self.futures_queue = futures_queue
self.get_response = get_response
self.aggregate = aggregate
super().__init__()
self.futures_queue.appendleft(self)
def result(self, timeout=None):
if timeout is not None:
raise RuntimeError("timeout not implemented")
# Drain any futures ahead of us in the queue.
while not self.done():
future, get_response = self.futures_queue.pop()
future.wait_for_response(get_response)
future = self.futures_queue.pop()
future._wait_for_response()
return super().result()
def wait_for_response(self, get_response: Callable):
def _wait_for_response(self):
try:
response = self.aggregate(get_response())
response = self.aggregate(self.get_response())
with suppress(InvalidStateError):
self.set_result(response)
except Exception as e:
@@ -218,7 +222,7 @@ class MultiprocExecutor(Executor):
for response_mq in self.response_mqs:
response_mq.wait_until_ready()
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
self.futures_queue = deque[FutureWrapper]()
self._post_init_executor()
@@ -384,17 +388,13 @@ class MultiprocExecutor(Executor):
responses.append(result)
return responses[0] if output_rank is not None else responses
if non_block:
future = FutureWrapper(self.futures_queue, aggregate=aggregate)
self.futures_queue.appendleft((future, get_response))
return future
future = FutureWrapper(
self.futures_queue,
get_response=get_response,
aggregate=aggregate,
)
# First drain any pending futures in the queue.
while self.futures_queue:
future, get_fut_response = self.futures_queue.pop()
future.wait_for_response(get_fut_response)
return aggregate(get_response())
return future if non_block else future.result()
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):