[Bugfix] Fix MessageQueue connect_ip for cross-node data parallelism (#35429)
Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
79
tests/distributed/test_mq_connect_ip.py
Normal file
79
tests/distributed/test_mq_connect_ip.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test that MessageQueue uses the local node's IP for binding,
|
||||
not a remote master_addr. This validates the fix for cross-node
|
||||
data-parallel where each DP group leader must bind to its own IP.
|
||||
|
||||
The bug: multiproc_executor used `parallel_config.master_addr` as
|
||||
`connect_ip` for every DP group's MessageQueue. For DP groups whose
|
||||
leader is NOT on the master node, binding to master_addr fails with
|
||||
"Cannot assign requested address".
|
||||
|
||||
The fix: use `get_ip()` (local node IP) instead of `master_addr`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
|
||||
def test_mq_bind_with_local_ip():
    """Bind a MessageQueue with one remote reader to the local node's IP.

    ``n_reader=2`` with ``n_local_reader=1`` leaves one remote reader, which
    triggers the remote ZMQ socket bind. The exported handle must then
    advertise a remote subscribe address that contains the local IP.
    """
    # Resolve the IP once so the value we bind to and the value we assert on
    # are guaranteed to be the same.
    local_ip = get_ip()
    mq = MessageQueue(
        n_reader=2,
        n_local_reader=1,
        connect_ip=local_ip,
    )
    try:
        handle = mq.export_handle()
        assert handle.remote_subscribe_addr is not None
        # A plain substring check also covers a bracketed "[<ip>]" form,
        # because that form contains the bare IP.
        assert local_ip in handle.remote_subscribe_addr
    finally:
        # Drop our reference to the queue even when an assertion fails.
        del mq
|
||||
|
||||
|
||||
def test_mq_bind_with_non_local_ip_fails():
    """Constructing a MessageQueue on a non-local IP must fail to bind.

    This simulates the original bug, where a ``master_addr`` belonging to a
    different node was used as ``connect_ip``.
    """
    # 198.51.100.1 is from TEST-NET-2 (RFC 5737), never locally assigned,
    # so we definitely cannot bind to it.
    non_local_ip = "198.51.100.1"
    with pytest.raises(zmq.error.ZMQError) as exc_info:
        MessageQueue(
            n_reader=2,
            n_local_reader=1,
            connect_ip=non_local_ip,
        )
    # Check the error code rather than the strerror text: the wording of
    # "Cannot assign requested address" differs between platforms.
    assert exc_info.value.errno == zmq.EADDRNOTAVAIL
|
||||
|
||||
|
||||
def test_mq_bind_defaults_to_local_ip():
    """Passing ``connect_ip=None`` must still produce a bound remote address.

    With no explicit IP, MessageQueue is expected to auto-detect the local
    one and bind successfully.
    """
    # connect_ip is passed explicitly as None; it should fall back to get_ip().
    queue = MessageQueue(n_reader=2, n_local_reader=1, connect_ip=None)
    exported = queue.export_handle()
    assert exported.remote_subscribe_addr is not None
    del queue
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running the checks directly as a script, reporting each in order.
    for check in (
        test_mq_bind_with_local_ip,
        test_mq_bind_with_non_local_ip_fails,
        test_mq_bind_defaults_to_local_ip,
    ):
        check()
        print(f"PASSED: {check.__name__}")
    print("\nAll tests passed!")
|
||||
@@ -44,6 +44,7 @@ from vllm.logger import init_logger
|
||||
from vllm.tracing import instrument, maybe_init_worker_tracer
|
||||
from vllm.utils.network_utils import (
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_loopback_ip,
|
||||
get_open_port,
|
||||
)
|
||||
@@ -128,11 +129,23 @@ class MultiprocExecutor(Executor):
|
||||
# For leader node within each dp rank,
|
||||
# each dp will have its own leader multiproc executor.
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
mq_connect_ip = get_ip()
|
||||
logger.info(
|
||||
"DP group leader: node_rank=%d, node_rank_within_dp=%d, "
|
||||
"master_addr=%s, mq_connect_ip=%s (local), "
|
||||
"world_size=%d, local_world_size=%d",
|
||||
self.parallel_config.node_rank,
|
||||
self.parallel_config.node_rank_within_dp,
|
||||
self.parallel_config.master_addr,
|
||||
mq_connect_ip,
|
||||
self.world_size,
|
||||
self.local_world_size,
|
||||
)
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size,
|
||||
self.local_world_size,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
connect_ip=self.parallel_config.master_addr,
|
||||
connect_ip=mq_connect_ip,
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
# Create workers
|
||||
|
||||
Reference in New Issue
Block a user