diff --git a/tests/utils_/test_network_utils.py b/tests/utils_/test_network_utils.py index bc274f067..157d43cb8 100644 --- a/tests/utils_/test_network_utils.py +++ b/tests/utils_/test_network_utils.py @@ -7,6 +7,7 @@ import zmq from vllm.utils.network_utils import ( get_open_port, + get_open_ports_list, get_tcp_uri, join_host_port, make_zmq_path, @@ -28,6 +29,25 @@ def test_get_open_port(monkeypatch: pytest.MonkeyPatch): s3.bind(("localhost", get_open_port())) +def test_get_open_ports_list_with_vllm_port(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_PORT", "5678") + ports = get_open_ports_list(5) + assert len(ports) == 5 + assert len(set(ports)) == 5, "ports must be unique" + + # verify every port is actually bindable + sockets = [] + try: + for p in ports: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("localhost", p)) + sockets.append(s) + finally: + for s in sockets: + s.close() + + @pytest.mark.parametrize( "path,expected", [ diff --git a/vllm/utils/network_utils.py b/vllm/utils/network_utils.py index 7d01533cb..6ffae768e 100644 --- a/vllm/utils/network_utils.py +++ b/vllm/utils/network_utils.py @@ -167,16 +167,34 @@ def get_open_port() -> int: def get_open_ports_list(count: int = 5) -> list[int]: - """Get a list of open ports.""" - ports = set[int]() - while len(ports) < count: - ports.add(get_open_port()) - return list(ports) + """Get a list of unique open ports. + + When VLLM_PORT is set, scans upward from that port, advancing + the start position after each find so every port is unique. + """ + ports_set = set[int]() + if envs.VLLM_PORT is not None: + next_port = envs.VLLM_PORT + for _ in range(count): + port = _get_open_port(start_port=next_port, max_attempts=1000) + ports_set.add(port) + next_port = port + 1 + return list(ports_set) + else: + while len(ports_set) < count: + ports_set.add(get_open_port()) + + return list(ports_set) -def _get_open_port() -> int: - port = envs.VLLM_PORT +def _get_open_port( + start_port: int | None = None, + max_attempts: int | None = None, +) -> int: + start_port = start_port if start_port is not None else envs.VLLM_PORT + port = start_port if port is not None: + attempts = 0 while True: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -185,6 +203,12 @@ def _get_open_port() -> int: except OSError: port += 1 # Increment port number if already in use logger.info("Port %d is already in use, trying port %d", port - 1, port) + attempts += 1 + if max_attempts is not None and attempts >= max_attempts: + raise RuntimeError( + f"Could not find open port after {max_attempts} " + f"attempts starting from port {start_port}" + ) # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: