[Refactor] Simplify FutureWrapper in MultiprocExecutor (#38644)
Signed-off-by: Yifan <yzong@redhat.com> Signed-off-by: Yifan Zong <yzong@redhat.com>
This commit is contained in:
@@ -67,25 +67,29 @@ logger = init_logger(__name__)
|
||||
class FutureWrapper(Future):
|
||||
def __init__(
|
||||
self,
|
||||
futures_queue: deque[tuple["FutureWrapper", Callable]],
|
||||
futures_queue: deque["FutureWrapper"],
|
||||
get_response: Callable[[], Any],
|
||||
aggregate: Callable = lambda x: x,
|
||||
):
|
||||
self.futures_queue = futures_queue
|
||||
self.get_response = get_response
|
||||
self.aggregate = aggregate
|
||||
super().__init__()
|
||||
self.futures_queue.appendleft(self)
|
||||
|
||||
def result(self, timeout=None):
|
||||
if timeout is not None:
|
||||
raise RuntimeError("timeout not implemented")
|
||||
|
||||
# Drain any futures ahead of us in the queue.
|
||||
while not self.done():
|
||||
future, get_response = self.futures_queue.pop()
|
||||
future.wait_for_response(get_response)
|
||||
future = self.futures_queue.pop()
|
||||
future._wait_for_response()
|
||||
return super().result()
|
||||
|
||||
def wait_for_response(self, get_response: Callable):
|
||||
def _wait_for_response(self):
|
||||
try:
|
||||
response = self.aggregate(get_response())
|
||||
response = self.aggregate(self.get_response())
|
||||
with suppress(InvalidStateError):
|
||||
self.set_result(response)
|
||||
except Exception as e:
|
||||
@@ -218,7 +222,7 @@ class MultiprocExecutor(Executor):
|
||||
for response_mq in self.response_mqs:
|
||||
response_mq.wait_until_ready()
|
||||
|
||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||
self.futures_queue = deque[FutureWrapper]()
|
||||
|
||||
self._post_init_executor()
|
||||
|
||||
@@ -384,17 +388,13 @@ class MultiprocExecutor(Executor):
|
||||
responses.append(result)
|
||||
return responses[0] if output_rank is not None else responses
|
||||
|
||||
if non_block:
|
||||
future = FutureWrapper(self.futures_queue, aggregate=aggregate)
|
||||
self.futures_queue.appendleft((future, get_response))
|
||||
return future
|
||||
future = FutureWrapper(
|
||||
self.futures_queue,
|
||||
get_response=get_response,
|
||||
aggregate=aggregate,
|
||||
)
|
||||
|
||||
# First drain any pending futures in the queue.
|
||||
while self.futures_queue:
|
||||
future, get_fut_response = self.futures_queue.pop()
|
||||
future.wait_for_response(get_fut_response)
|
||||
|
||||
return aggregate(get_response())
|
||||
return future if non_block else future.result()
|
||||
|
||||
@staticmethod
|
||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||
|
||||
Reference in New Issue
Block a user