Fix DP coordinator ZMQ port TOCTOU (#37452)

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
This commit is contained in:
Itay Alroy
2026-03-20 02:58:31 +02:00
committed by GitHub
parent 4ca3fa6bb4
commit ca1ac1a4b4
2 changed files with 58 additions and 8 deletions

View File

@@ -247,7 +247,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:
scheme = parsed.scheme
host = parsed.hostname or ""
port = str(parsed.port or "")
port = "" if parsed.port is None else str(parsed.port)
if host.startswith("[") and host.endswith("]"):
host = host[1:-1] # Remove brackets for IPv6 address

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import multiprocessing
import multiprocessing.connection
import time
import weakref
@@ -10,7 +11,7 @@ import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.network_utils import get_tcp_uri, make_zmq_socket
from vllm.utils.system_utils import get_mp_context, set_process_title
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
@@ -55,6 +56,25 @@ class DPCoordinator:
request wave / running state changes.
"""
def _wait_for_zmq_addrs(self, zmq_addr_pipe) -> tuple[str, str, str]:
try:
ready = multiprocessing.connection.wait(
[zmq_addr_pipe, self.proc.sentinel], timeout=30
)
if not ready:
raise RuntimeError(
"DP Coordinator process failed to report ZMQ addresses "
"during startup."
)
try:
return zmq_addr_pipe.recv()
except EOFError:
raise RuntimeError(
"DP Coordinator process failed during startup."
) from None
finally:
zmq_addr_pipe.close()
def __init__(
self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True
):
@@ -66,18 +86,24 @@ class DPCoordinator:
# Assume coordinator is colocated with front-end procs when not in
# either external or hybrid DP LB mode.
local_only = not parallel_config.local_engines_only
front_publish_address = get_engine_client_zmq_addr(
local_only=local_only, host=host
)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
def bind_address(local_only: bool) -> str:
    # Colocated (local-only) peers get a concrete engine-client
    # address up front; TCP peers bind port 0 so the OS assigns a
    # free port at bind time (the port-TOCTOU fix this change makes).
    if local_only:
        return get_engine_client_zmq_addr(local_only=True, host=host)
    return get_tcp_uri(host, 0)
front_publish_address = bind_address(local_only)
back_publish_address = bind_address(local_only_eng)
back_output_address = bind_address(local_only_eng)
context = get_mp_context()
parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False)
self.proc: multiprocessing.Process = context.Process(
target=DPCoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
@@ -86,11 +112,18 @@ class DPCoordinator:
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
"zmq_addr_pipe": child_zmq_addr_pipe,
"enable_wave_coordination": enable_wave_coordination,
},
daemon=True,
)
self.proc.start()
child_zmq_addr_pipe.close()
(
front_publish_address,
back_output_address,
back_publish_address,
) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe)
self.stats_publish_address = front_publish_address
self.coord_in_address = back_publish_address
@@ -136,6 +169,7 @@ class DPCoordinatorProc:
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
zmq_addr_pipe=None,
min_stats_update_interval_ms: int = 100,
enable_wave_coordination: bool = True,
):
@@ -149,15 +183,20 @@ class DPCoordinatorProc:
front_publish_address,
back_output_address,
back_publish_address,
zmq_addr_pipe,
)
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
finally:
if zmq_addr_pipe is not None:
zmq_addr_pipe.close()
def process_input_socket(
self,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
zmq_addr_pipe=None,
):
decoder = MsgpackDecoder(EngineCoreOutputs)
@@ -191,6 +230,17 @@ class DPCoordinatorProc:
bind=True,
) as publish_back,
):
if zmq_addr_pipe is not None:
try:
zmq_addr_pipe.send(
(
publish_front.getsockopt(zmq.LAST_ENDPOINT).decode(),
output_back.getsockopt(zmq.LAST_ENDPOINT).decode(),
publish_back.getsockopt(zmq.LAST_ENDPOINT).decode(),
)
)
finally:
zmq_addr_pipe.close()
# Wait until all engines subscribe.
for _ in self.engines:
if publish_back.recv() != b"\x01":