Fix DP coordinator ZMQ port TOCTOU (#37452)
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
This commit is contained in:
@@ -247,7 +247,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:
|
||||
|
||||
scheme = parsed.scheme
|
||||
host = parsed.hostname or ""
|
||||
port = str(parsed.port or "")
|
||||
port = "" if parsed.port is None else str(parsed.port)
|
||||
if host.startswith("[") and host.endswith("]"):
|
||||
host = host[1:-1] # Remove brackets for IPv6 address
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import multiprocessing
|
||||
import multiprocessing.connection
|
||||
import time
|
||||
import weakref
|
||||
|
||||
@@ -10,7 +11,7 @@ import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.utils.network_utils import get_tcp_uri, make_zmq_socket
|
||||
from vllm.utils.system_utils import get_mp_context, set_process_title
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
@@ -55,6 +56,25 @@ class DPCoordinator:
|
||||
request wave / running state changes.
|
||||
"""
|
||||
|
||||
def _wait_for_zmq_addrs(self, zmq_addr_pipe) -> tuple[str, str, str]:
|
||||
try:
|
||||
ready = multiprocessing.connection.wait(
|
||||
[zmq_addr_pipe, self.proc.sentinel], timeout=30
|
||||
)
|
||||
if not ready:
|
||||
raise RuntimeError(
|
||||
"DP Coordinator process failed to report ZMQ addresses "
|
||||
"during startup."
|
||||
)
|
||||
try:
|
||||
return zmq_addr_pipe.recv()
|
||||
except EOFError:
|
||||
raise RuntimeError(
|
||||
"DP Coordinator process failed during startup."
|
||||
) from None
|
||||
finally:
|
||||
zmq_addr_pipe.close()
|
||||
|
||||
def __init__(
|
||||
self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True
|
||||
):
|
||||
@@ -66,18 +86,24 @@ class DPCoordinator:
|
||||
# Assume coordinator is colocated with front-end procs when not in
|
||||
# either external or hybrid DP LB mode.
|
||||
local_only = not parallel_config.local_engines_only
|
||||
front_publish_address = get_engine_client_zmq_addr(
|
||||
local_only=local_only, host=host
|
||||
)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
# NOTE(yongji): handling scaling from intra-node to inter-node
|
||||
if parallel_config.enable_elastic_ep:
|
||||
local_only_eng = False
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
|
||||
def bind_address(local_only: bool) -> str:
|
||||
return (
|
||||
get_engine_client_zmq_addr(local_only=True, host=host)
|
||||
if local_only
|
||||
else get_tcp_uri(host, 0)
|
||||
)
|
||||
|
||||
front_publish_address = bind_address(local_only)
|
||||
back_publish_address = bind_address(local_only_eng)
|
||||
back_output_address = bind_address(local_only_eng)
|
||||
|
||||
context = get_mp_context()
|
||||
parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False)
|
||||
self.proc: multiprocessing.Process = context.Process(
|
||||
target=DPCoordinatorProc.run_coordinator,
|
||||
name="VLLM_DP_Coordinator",
|
||||
@@ -86,11 +112,18 @@ class DPCoordinator:
|
||||
"front_publish_address": front_publish_address,
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
"zmq_addr_pipe": child_zmq_addr_pipe,
|
||||
"enable_wave_coordination": enable_wave_coordination,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
self.proc.start()
|
||||
child_zmq_addr_pipe.close()
|
||||
(
|
||||
front_publish_address,
|
||||
back_output_address,
|
||||
back_publish_address,
|
||||
) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe)
|
||||
|
||||
self.stats_publish_address = front_publish_address
|
||||
self.coord_in_address = back_publish_address
|
||||
@@ -136,6 +169,7 @@ class DPCoordinatorProc:
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
zmq_addr_pipe=None,
|
||||
min_stats_update_interval_ms: int = 100,
|
||||
enable_wave_coordination: bool = True,
|
||||
):
|
||||
@@ -149,15 +183,20 @@ class DPCoordinatorProc:
|
||||
front_publish_address,
|
||||
back_output_address,
|
||||
back_publish_address,
|
||||
zmq_addr_pipe,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DP Coordinator process exiting")
|
||||
finally:
|
||||
if zmq_addr_pipe is not None:
|
||||
zmq_addr_pipe.close()
|
||||
|
||||
def process_input_socket(
|
||||
self,
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
zmq_addr_pipe=None,
|
||||
):
|
||||
decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
@@ -191,6 +230,17 @@ class DPCoordinatorProc:
|
||||
bind=True,
|
||||
) as publish_back,
|
||||
):
|
||||
if zmq_addr_pipe is not None:
|
||||
try:
|
||||
zmq_addr_pipe.send(
|
||||
(
|
||||
publish_front.getsockopt(zmq.LAST_ENDPOINT).decode(),
|
||||
output_back.getsockopt(zmq.LAST_ENDPOINT).decode(),
|
||||
publish_back.getsockopt(zmq.LAST_ENDPOINT).decode(),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
zmq_addr_pipe.close()
|
||||
# Wait until all engines subscribe.
|
||||
for _ in self.engines:
|
||||
if publish_back.recv() != b"\x01":
|
||||
|
||||
Reference in New Issue
Block a user