[PerfFix] Avoid separate thread for MP executor shm spin (#28012)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -9,8 +9,10 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from concurrent.futures import Future, InvalidStateError
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import cached_property, partial
|
||||
@@ -54,6 +56,30 @@ from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FutureWrapper(Future):
    """A ``Future`` that is resolved by whichever caller waits on it.

    Every instance holds a reference to a shared deque of
    ``(future, get_response)`` pairs. No background thread completes these
    futures: calling :meth:`result` pops pairs from the right-hand end of the
    deque and resolves each popped future in turn until this one is done.
    """

    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
        """Store the shared queue of pending ``(future, get_response)`` pairs.

        Args:
            futures_queue: Deque this future will drain from in
                :meth:`result`. It is kept by reference, not copied.
        """
        self.futures_queue = futures_queue
        super().__init__()

    def result(self, timeout=None):
        """Resolve queued futures until this one is done, then return its value.

        Args:
            timeout: Must be ``None``; timeouts are not supported.

        Returns:
            The value stored on this future.

        Raises:
            RuntimeError: If ``timeout`` is given, or if this future is not
                done and the shared queue has no entries left that could
                complete it.
            Exception: Whatever exception was stored on this future.
        """
        if timeout is not None:
            raise RuntimeError("timeout not implemented")
        # Drain any futures ahead of us in the queue.
        while not self.done():
            if not self.futures_queue:
                # Without this guard the pop() below would surface as an
                # opaque "IndexError: pop from an empty deque".
                raise RuntimeError(
                    "FutureWrapper is not done and its futures queue is "
                    "empty; nothing is left to produce a result"
                )
            future, get_response = self.futures_queue.pop()
            future.wait_for_response(get_response)
        return super().result()

    def wait_for_response(self, get_response: Callable):
        """Call ``get_response`` and record its outcome on this future.

        The return value is stored as the result; an ``Exception`` raised by
        the call is stored as the future's exception instead of propagating.
        If the future is already in a final state, the resulting
        ``InvalidStateError`` is suppressed and the stored outcome is kept.

        Args:
            get_response: Zero-argument callable producing the response.
        """
        try:
            response = get_response()
            with suppress(InvalidStateError):
                self.set_result(response)
        except Exception as e:
            with suppress(InvalidStateError):
                self.set_exception(e)
|
||||
|
||||
|
||||
class MultiprocExecutor(Executor):
|
||||
supports_pp: bool = True
|
||||
|
||||
@@ -64,7 +90,6 @@ class MultiprocExecutor(Executor):
|
||||
self.is_failed = False
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: FailureCallback | None = None
|
||||
self.io_thread_pool: ThreadPoolExecutor | None = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
@@ -132,12 +157,7 @@ class MultiprocExecutor(Executor):
|
||||
uw.death_writer.close()
|
||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||
|
||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||
# from the response queue.
|
||||
# _async_aggregate_workers_output also assumes a single IO thread.
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io"
|
||||
)
|
||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
@@ -195,14 +215,13 @@ class MultiprocExecutor(Executor):
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
if not self.has_connector:
|
||||
# get output only from a single worker (output_rank)
|
||||
(output,) = self.collective_rpc(
|
||||
return self.collective_rpc(
|
||||
method,
|
||||
args=args,
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
)
|
||||
return output
|
||||
|
||||
# get output from all workers
|
||||
outputs = self.collective_rpc(
|
||||
@@ -223,12 +242,11 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
# OPTIMIZATION: Get output only from a single worker (output_rank)
|
||||
outputs = self.collective_rpc(
|
||||
return self.collective_rpc(
|
||||
"take_draft_token_ids", unique_reply_rank=self.output_rank
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
def collective_rpc(
|
||||
def collective_rpc( # type: ignore[override]
|
||||
self,
|
||||
method: str | Callable,
|
||||
timeout: float | None = None,
|
||||
@@ -236,7 +254,9 @@ class MultiprocExecutor(Executor):
|
||||
kwargs: dict | None = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: int | None = None,
|
||||
) -> list[Any]:
|
||||
) -> Any | list[Any] | Future[Any | list[Any]]:
|
||||
"""Returns single result if unique_reply_rank is provided, otherwise list."""
|
||||
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
|
||||
@@ -246,63 +266,52 @@ class MultiprocExecutor(Executor):
|
||||
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
||||
# and unpack them in the method of every worker, because every worker
|
||||
# knows their own rank.
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
send_method = method
|
||||
else:
|
||||
send_method = cloudpickle.dumps(
|
||||
method, protocol=pickle.HIGHEST_PROTOCOL
|
||||
)
|
||||
self.rpc_broadcast_mq.enqueue(
|
||||
(send_method, args, kwargs, unique_reply_rank)
|
||||
)
|
||||
|
||||
workers = (
|
||||
(self.workers[unique_reply_rank],)
|
||||
if unique_reply_rank is not None
|
||||
else self.workers
|
||||
)
|
||||
if isinstance(method, str):
|
||||
send_method = method
|
||||
else:
|
||||
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
|
||||
|
||||
workers = (
|
||||
(self.workers[unique_reply_rank],)
|
||||
if unique_reply_rank is not None
|
||||
else self.workers
|
||||
)
|
||||
|
||||
shutdown_event = self.shutdown_event
|
||||
|
||||
def get_response():
|
||||
responses = []
|
||||
|
||||
def get_response(
|
||||
w: WorkerProcHandle,
|
||||
dequeue_timeout: float | None = None,
|
||||
cancel_event: threading.Event | None = None,
|
||||
):
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=cancel_event
|
||||
for w in workers:
|
||||
dequeue_timeout = (
|
||||
None if deadline is None else (deadline - time.monotonic())
|
||||
)
|
||||
|
||||
try:
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=shutdown_event
|
||||
)
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||
raise RuntimeError(
|
||||
f"Worker failed with error '{result}', please check the"
|
||||
" stack trace above for the root cause"
|
||||
)
|
||||
return result
|
||||
|
||||
for w in workers:
|
||||
dequeue_timeout = (
|
||||
None if deadline is None else (deadline - time.monotonic())
|
||||
)
|
||||
|
||||
if self.io_thread_pool is not None:
|
||||
# We must consume worker_response_mq from a single thread.
|
||||
result = self.io_thread_pool.submit( # type: ignore
|
||||
get_response, w, dequeue_timeout, self.shutdown_event
|
||||
)
|
||||
if not non_block:
|
||||
result = result.result()
|
||||
elif not non_block:
|
||||
result = get_response(w, dequeue_timeout, self.shutdown_event)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"non_block can only be used when max_concurrent_batches > 1"
|
||||
)
|
||||
responses.append(result)
|
||||
return responses[0] if unique_reply_rank is not None else responses
|
||||
|
||||
return responses
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
if non_block:
|
||||
future = FutureWrapper(self.futures_queue)
|
||||
self.futures_queue.appendleft((future, get_response))
|
||||
return future
|
||||
|
||||
# First drain any pending futures in the queue.
|
||||
while self.futures_queue:
|
||||
future, get_fut_response = self.futures_queue.pop()
|
||||
future.wait_for_response(get_fut_response)
|
||||
|
||||
return get_response()
|
||||
|
||||
@staticmethod
|
||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||
@@ -348,9 +357,6 @@ class MultiprocExecutor(Executor):
|
||||
self._ensure_worker_termination([w.proc for w in workers])
|
||||
|
||||
self.shutdown_event.set()
|
||||
if self.io_thread_pool is not None:
|
||||
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
|
||||
del self.io_thread_pool
|
||||
|
||||
self.rpc_broadcast_mq = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user