From ca1ac1a4b44f46c53747f6792507ec1927ade617 Mon Sep 17 00:00:00 2001 From: Itay Alroy <75032521+itayalroy@users.noreply.github.com> Date: Fri, 20 Mar 2026 02:58:31 +0200 Subject: [PATCH] Fix DP coordinator ZMQ port TOCTOU (#37452) Signed-off-by: Itay Alroy --- vllm/utils/network_utils.py | 2 +- vllm/v1/engine/coordinator.py | 64 +++++++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/vllm/utils/network_utils.py b/vllm/utils/network_utils.py index 6b940c92d..6152bb0b2 100644 --- a/vllm/utils/network_utils.py +++ b/vllm/utils/network_utils.py @@ -247,7 +247,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]: scheme = parsed.scheme host = parsed.hostname or "" - port = str(parsed.port or "") + port = "" if parsed.port is None else str(parsed.port) if host.startswith("[") and host.endswith("]"): host = host[1:-1] # Remove brackets for IPv6 address diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 28cd13758..8ebf976c5 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy import multiprocessing +import multiprocessing.connection import time import weakref @@ -10,7 +11,7 @@ import zmq from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils.network_utils import make_zmq_socket +from vllm.utils.network_utils import get_tcp_uri, make_zmq_socket from vllm.utils.system_utils import get_mp_context, set_process_title from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType from vllm.v1.serial_utils import MsgpackDecoder @@ -55,6 +56,25 @@ class DPCoordinator: request wave / running state changes. """ + def _wait_for_zmq_addrs(self, zmq_addr_pipe) -> tuple[str, str, str]: + try: + ready = multiprocessing.connection.wait( + [zmq_addr_pipe, self.proc.sentinel], timeout=30 + ) + if not ready: + raise RuntimeError( + "DP Coordinator process failed to report ZMQ addresses " + "during startup." + ) + try: + return zmq_addr_pipe.recv() + except EOFError: + raise RuntimeError( + "DP Coordinator process failed during startup." + ) from None + finally: + zmq_addr_pipe.close() + def __init__( self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True ): @@ -66,18 +86,24 @@ class DPCoordinator: # Assume coordinator is colocated with front-end procs when not in # either external or hybrid DP LB mode. local_only = not parallel_config.local_engines_only - front_publish_address = get_engine_client_zmq_addr( - local_only=local_only, host=host - ) - local_only_eng = dp_size == parallel_config.data_parallel_size_local # NOTE(yongji): handling scaling from intra-node to inter-node if parallel_config.enable_elastic_ep: local_only_eng = False - back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) - back_output_address = get_engine_client_zmq_addr(local_only_eng, host) + + def bind_address(local_only: bool) -> str: + return ( + get_engine_client_zmq_addr(local_only=True, host=host) + if local_only + else get_tcp_uri(host, 0) + ) + + front_publish_address = bind_address(local_only) + back_publish_address = bind_address(local_only_eng) + back_output_address = bind_address(local_only_eng) context = get_mp_context() + parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False) self.proc: multiprocessing.Process = context.Process( target=DPCoordinatorProc.run_coordinator, name="VLLM_DP_Coordinator", @@ -86,11 +112,18 @@ class DPCoordinator: "front_publish_address": front_publish_address, "back_output_address": back_output_address, "back_publish_address": back_publish_address, + "zmq_addr_pipe": child_zmq_addr_pipe, "enable_wave_coordination": enable_wave_coordination, }, daemon=True, ) self.proc.start() + child_zmq_addr_pipe.close() + ( + front_publish_address, + back_output_address, + back_publish_address, + ) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe) self.stats_publish_address = front_publish_address self.coord_in_address = back_publish_address @@ -136,6 +169,7 @@ class DPCoordinatorProc: front_publish_address: str, back_output_address: str, back_publish_address: str, + zmq_addr_pipe=None, min_stats_update_interval_ms: int = 100, enable_wave_coordination: bool = True, ): @@ -149,15 +183,20 @@ class DPCoordinatorProc: front_publish_address, back_output_address, back_publish_address, + zmq_addr_pipe, ) except KeyboardInterrupt: logger.info("DP Coordinator process exiting") + finally: + if zmq_addr_pipe is not None: + zmq_addr_pipe.close() def process_input_socket( self, front_publish_address: str, back_output_address: str, back_publish_address: str, + zmq_addr_pipe=None, ): decoder = MsgpackDecoder(EngineCoreOutputs) @@ -191,6 +230,17 @@ class DPCoordinatorProc: bind=True, ) as publish_back, ): + if zmq_addr_pipe is not None: + try: + zmq_addr_pipe.send( + ( + publish_front.getsockopt(zmq.LAST_ENDPOINT).decode(), + output_back.getsockopt(zmq.LAST_ENDPOINT).decode(), + publish_back.getsockopt(zmq.LAST_ENDPOINT).decode(), + ) + ) + finally: + zmq_addr_pipe.close() # Wait until all engines subscribe. for _ in self.engines: if publish_back.recv() != b"\x01":