[Refactor] Simplify FutureWrapper in MultiprocExecutor (#38644)
Signed-off-by: Yifan <yzong@redhat.com> Signed-off-by: Yifan Zong <yzong@redhat.com>
This commit is contained in:
@@ -67,25 +67,29 @@ logger = init_logger(__name__)
|
|||||||
class FutureWrapper(Future):
|
class FutureWrapper(Future):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
futures_queue: deque[tuple["FutureWrapper", Callable]],
|
futures_queue: deque["FutureWrapper"],
|
||||||
|
get_response: Callable[[], Any],
|
||||||
aggregate: Callable = lambda x: x,
|
aggregate: Callable = lambda x: x,
|
||||||
):
|
):
|
||||||
self.futures_queue = futures_queue
|
self.futures_queue = futures_queue
|
||||||
|
self.get_response = get_response
|
||||||
self.aggregate = aggregate
|
self.aggregate = aggregate
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.futures_queue.appendleft(self)
|
||||||
|
|
||||||
def result(self, timeout=None):
|
def result(self, timeout=None):
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
raise RuntimeError("timeout not implemented")
|
raise RuntimeError("timeout not implemented")
|
||||||
|
|
||||||
# Drain any futures ahead of us in the queue.
|
# Drain any futures ahead of us in the queue.
|
||||||
while not self.done():
|
while not self.done():
|
||||||
future, get_response = self.futures_queue.pop()
|
future = self.futures_queue.pop()
|
||||||
future.wait_for_response(get_response)
|
future._wait_for_response()
|
||||||
return super().result()
|
return super().result()
|
||||||
|
|
||||||
def wait_for_response(self, get_response: Callable):
|
def _wait_for_response(self):
|
||||||
try:
|
try:
|
||||||
response = self.aggregate(get_response())
|
response = self.aggregate(self.get_response())
|
||||||
with suppress(InvalidStateError):
|
with suppress(InvalidStateError):
|
||||||
self.set_result(response)
|
self.set_result(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -218,7 +222,7 @@ class MultiprocExecutor(Executor):
|
|||||||
for response_mq in self.response_mqs:
|
for response_mq in self.response_mqs:
|
||||||
response_mq.wait_until_ready()
|
response_mq.wait_until_ready()
|
||||||
|
|
||||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
self.futures_queue = deque[FutureWrapper]()
|
||||||
|
|
||||||
self._post_init_executor()
|
self._post_init_executor()
|
||||||
|
|
||||||
@@ -384,17 +388,13 @@ class MultiprocExecutor(Executor):
|
|||||||
responses.append(result)
|
responses.append(result)
|
||||||
return responses[0] if output_rank is not None else responses
|
return responses[0] if output_rank is not None else responses
|
||||||
|
|
||||||
if non_block:
|
future = FutureWrapper(
|
||||||
future = FutureWrapper(self.futures_queue, aggregate=aggregate)
|
self.futures_queue,
|
||||||
self.futures_queue.appendleft((future, get_response))
|
get_response=get_response,
|
||||||
return future
|
aggregate=aggregate,
|
||||||
|
)
|
||||||
|
|
||||||
# First drain any pending futures in the queue.
|
return future if non_block else future.result()
|
||||||
while self.futures_queue:
|
|
||||||
future, get_fut_response = self.futures_queue.pop()
|
|
||||||
future.wait_for_response(get_fut_response)
|
|
||||||
|
|
||||||
return aggregate(get_response())
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||||
|
|||||||
Reference in New Issue
Block a user