diff --git a/.buildkite/test_areas/expert_parallelism.yaml b/.buildkite/test_areas/expert_parallelism.yaml index 1443d847e..63404fc5d 100644 --- a/.buildkite/test_areas/expert_parallelism.yaml +++ b/.buildkite/test_areas/expert_parallelism.yaml @@ -24,8 +24,7 @@ steps: - label: Elastic EP Scaling Test timeout_in_minutes: 20 - device: b200 - optional: true + device: h100 working_dir: "/vllm-workspace/tests" num_devices: 4 source_file_dependencies: diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index d4048a473..add011ca4 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +import socket from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal, overload @@ -266,33 +267,9 @@ class ParallelConfig: Set to be private as it's not intended to be configured by users. """ - _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless DP groups when enable_elastic_ep is True. - Set to be private as it's not intended to be configured by users. - It is a list of list[int], with each inner list contains a set of 3 ports - to be used for setting up the stateless CPU/device/TCPStore groups - in StatelessGroupCoordinator. The number of inner lists is equal to - the number of DP groups, - i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size, - and len(self._stateless_dp_group_port_list[i]) == 3 for all i. - """ - - _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless EP groups when enable_elastic_ep is True. - Set to be private as it's not intended to be configured by users. - len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size, - """ - - _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless EPLB groups when enable_elastic_ep is True. - Same topology as EP but separate NCCL communicator to avoid deadlocks. - """ - - _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless world group when enable_elastic_ep is True. - Set to be private as it's not intended to be configured by users. - len(self._stateless_world_group_port_list) == 1, - """ + _coord_store_port: int = 0 + """Port of the coordination TCPStore. Can be set by the API server; workers + connect as clients to exchange self-picked group ports at runtime.""" decode_context_parallel_size: int = 1 """Number of decode context parallel groups, because the world size does @@ -465,65 +442,32 @@ class ParallelConfig: return answer - def allocate_elastic_ep_ports(self) -> None: - """Allocate all ports for elastic EP (stateless groups + DP master). + def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]: + """Return ``(port, listen_socket)`` for DP group init. - Must be called AFTER ray.init() so that ports claimed by Ray's - idle worker pool are already in use and won't be returned by - get_open_ports_list(). + With a coord store, rank 0 binds a socket and publishes the port; + others read it. Without one, pops a pre-allocated port and + returns ``listen_socket=None``. 
""" - if not self.enable_elastic_ep: - return - if self._stateless_world_group_port_list: - return + if not self._coord_store_port: + return self.get_next_dp_init_port(), None - num_world_groups = 1 - dp_size = self.data_parallel_size - ep_size = self.data_parallel_size * self.world_size_across_dp - num_dp_groups = max(1, self.world_size_across_dp // dp_size) - num_ep_groups = max(1, self.world_size_across_dp // ep_size) - num_eplb_groups = num_ep_groups - total_stateless_ports = ( - num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups - ) * 3 - num_dp_master_ports = 5 + from vllm.distributed.utils import get_cached_tcp_store_client - all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports) + store = get_cached_tcp_store_client( + self.data_parallel_master_ip, self._coord_store_port + ) - self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:] - self.data_parallel_master_port = self._data_parallel_master_port_list.pop() - all_ports = all_ports[:-num_dp_master_ports] - - self._stateless_world_group_port_list = [ - all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) - ] - start_idx = num_world_groups * 3 - self._stateless_dp_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_dp_groups * 3, 3) - ] - start_idx += num_dp_groups * 3 - self._stateless_ep_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_ep_groups * 3, 3) - ] - start_idx += num_ep_groups * 3 - self._stateless_eplb_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) - ] - - def get_next_stateless_world_group_port(self) -> list[int]: - return self._stateless_world_group_port_list.pop() - - def get_next_stateless_dp_group_port(self) -> list[int]: - return self._stateless_dp_group_port_list.pop() - - def get_next_stateless_ep_group_port(self) -> list[int]: - return self._stateless_ep_group_port_list.pop() - - def get_next_stateless_eplb_group_port(self) -> list[int]: - return self._stateless_eplb_group_port_list.pop() + key = "dp_master_port" + if self.data_parallel_rank == 0: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind((self.data_parallel_master_ip, 0)) + s.listen() + port = s.getsockname()[1] + store.set(key, str(port).encode()) + return port, s + else: + return int(store.get(key).decode()), None @overload def stateless_init_dp_group( @@ -553,14 +497,16 @@ class ParallelConfig: last_exc: Exception | None = None for _ in range(max_retries): try: + port, listen_socket = self._pick_stateless_dp_port() # use gloo since the engine process might not have cuda device return stateless_init_torch_distributed_process_group( self.data_parallel_master_ip, - self.get_next_dp_init_port(), + port, self.data_parallel_rank, self.data_parallel_size, backend="gloo", return_store=return_store, + listen_socket=listen_socket, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. 
diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py index 516d2c256..00ac6d84b 100644 --- a/vllm/distributed/elastic_ep/elastic_execute.py +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -162,10 +162,8 @@ class ElasticEPScalingExecutor: new_dp_size=new_dp_size, new_world_size_across_dp=new_world_size_across_dp, master_ip=reconfig_request.new_data_parallel_master_ip, - world_group_ports=reconfig_request.new_stateless_world_group_port_list, - dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, - ep_group_ports=reconfig_request.new_stateless_ep_group_port_list, - eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list, + coord_store_port=reconfig_request.coord_store_port, + enable_eplb=updated_config.parallel_config.enable_eplb, ) self.worker.model_runner.eep_eplb_suppressed = True standby_ep_group = get_standby_ep_group() diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py index fce0d8361..cd989a49a 100644 --- a/vllm/distributed/elastic_ep/elastic_state.py +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -563,15 +563,4 @@ class ElasticEPScalingState: parallel_config._data_parallel_master_port_list = ( reconfig_request.new_data_parallel_master_port_list ) - parallel_config._stateless_world_group_port_list = ( - reconfig_request.new_stateless_world_group_port_list - ) - parallel_config._stateless_dp_group_port_list = ( - reconfig_request.new_stateless_dp_group_port_list - ) - parallel_config._stateless_ep_group_port_list = ( - reconfig_request.new_stateless_ep_group_port_list - ) - parallel_config._stateless_eplb_group_port_list = ( - reconfig_request.new_stateless_eplb_group_port_list - ) + parallel_config._coord_store_port = reconfig_request.coord_store_port diff --git a/vllm/distributed/elastic_ep/standby_state.py b/vllm/distributed/elastic_ep/standby_state.py index d11e0b550..846793a95 100644 --- a/vllm/distributed/elastic_ep/standby_state.py +++ b/vllm/distributed/elastic_ep/standby_state.py @@ -38,10 +38,8 @@ def create_standby_groups( new_dp_size: int, new_world_size_across_dp: int, master_ip: str, - world_group_ports: list[list[int]], - dp_group_ports: list[list[int]], - ep_group_ports: list[list[int]], - eplb_group_ports: list[list[int]] | None = None, + coord_store_port: int, + enable_eplb: bool = True, backend: str | None = None, ) -> None: global \ @@ -51,19 +49,23 @@ def create_standby_groups( _STANDBY_EP, \ _STANDBY_EPLB + from vllm.distributed.utils import get_cached_tcp_store_client + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size world_group = get_world_group() assert isinstance(world_group, StatelessGroupCoordinator) backend = backend or world_group.backend + coord_store = get_cached_tcp_store_client(master_ip, coord_store_port) + standby_world_ranks = [list(range(new_world_size_across_dp))] _STANDBY_WORLD = _init_stateless_group( standby_world_ranks, "world", - world_group_ports, master_ip, backend, use_device_communicator=False, + coord_store=coord_store, ) _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) @@ -76,7 +78,7 @@ def create_standby_groups( standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] _STANDBY_DP = _init_stateless_group( - standby_dp_ranks, "dp", dp_group_ports, master_ip, backend + standby_dp_ranks, "dp", master_ip, backend, coord_store=coord_store ) standby_ep_ranks 
= ( @@ -84,12 +86,16 @@ def create_standby_groups( ) standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] _STANDBY_EP = _init_stateless_group( - standby_ep_ranks, "ep", ep_group_ports, master_ip, backend + standby_ep_ranks, "ep", master_ip, backend, coord_store=coord_store ) - if eplb_group_ports is not None: + if enable_eplb: _STANDBY_EPLB = _init_stateless_group( - standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend + standby_ep_ranks, + "eplb", + master_ip, + backend, + coord_store=coord_store, ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index af1bc6b14..04187b34e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -40,13 +40,16 @@ import torch import torch.distributed import torch.distributed._functional_collectives as funcol import torch.distributed._symmetric_memory -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend, ProcessGroup, Store import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase, ) -from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.utils import ( + StatelessProcessGroup, + get_cached_tcp_store_client, +) from vllm.logger import init_logger from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.network_utils import get_distributed_init_method @@ -1164,9 +1167,9 @@ def init_model_parallel_group( def _init_stateless_group( group_ranks: list[list[int]], group_name: str, - group_ports: list[list[int]], host: str, backend: str, + coord_store: Store, use_device_communicator: bool = True, ) -> "StatelessGroupCoordinator": """Create a StatelessGroupCoordinator with the given parameters.""" @@ -1180,7 +1183,7 @@ def _init_stateless_group( use_device_communicator=use_device_communicator, group_name=group_name, host=host, - group_ports=group_ports, + coord_store=coord_store, global_rank=world.rank, global_world_size=world.world_size, ) @@ -1321,7 +1324,9 @@ def _init_elastic_ep_world( group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)] if global_rank in all_ranks: group_ranks = [all_ranks] - group_ports = [parallel_config.get_next_stateless_world_group_port()] + coord_store = get_cached_tcp_store_client( + parallel_config.data_parallel_master_ip, parallel_config._coord_store_port + ) world = StatelessGroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, @@ -1329,7 +1334,7 @@ def _init_elastic_ep_world( use_device_communicator=False, group_name="world", host=parallel_config.data_parallel_master_ip, - group_ports=group_ports, + coord_store=coord_store, global_rank=global_rank, global_world_size=global_world_size, ) @@ -1513,7 +1518,13 @@ def initialize_model_parallel( config = get_current_vllm_config() data_parallel_size = config.parallel_config.data_parallel_size enable_elastic_ep = config.parallel_config.enable_elastic_ep + parallel_config = config.parallel_config + coord_store: Store | None = None if enable_elastic_ep: + coord_store = get_cached_tcp_store_client( + parallel_config.data_parallel_master_ip, + parallel_config._coord_store_port, + ) # Use stateless world group for global information world_size = get_world_group().world_size rank = get_world_group().rank @@ -1633,16 +1644,12 @@ def initialize_model_parallel( group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] if enable_elastic_ep: - parallel_config = 
config.parallel_config - dp_ports = [ - parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks - ] _DP = _init_stateless_group( group_ranks, "dp", - dp_ports, parallel_config.data_parallel_master_ip, backend, + coord_store=coord_store, ) else: _DP = init_model_parallel_group( @@ -1665,16 +1672,12 @@ def initialize_model_parallel( ) group_ranks = [x.tolist() for x in group_ranks] if enable_elastic_ep: - parallel_config = config.parallel_config - ep_ports = [ - parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks - ] _EP = _init_stateless_group( group_ranks, "ep", - ep_ports, parallel_config.data_parallel_master_ip, backend, + coord_store=coord_store, ) else: _EP = init_model_parallel_group( @@ -1693,16 +1696,12 @@ def initialize_model_parallel( and config.parallel_config.enable_eplb ): if enable_elastic_ep: - eplb_ports = [ - parallel_config.get_next_stateless_eplb_group_port() - for _ in group_ranks - ] _EPLB = _init_stateless_group( group_ranks, "eplb", - eplb_ports, parallel_config.data_parallel_master_ip, backend, + coord_store=coord_store, ) else: _EPLB = init_model_parallel_group( diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py index f2126fdba..549284df3 100644 --- a/vllm/distributed/stateless_coordinator.py +++ b/vllm/distributed/stateless_coordinator.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import socket +import struct from typing import Any, Optional import torch -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend, ProcessGroup, Store from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator from vllm.distributed.parallel_state import ( @@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) +_PORTS_FMT = "!3I" + + +def _allocate_group_ports( + key: str, + host: str, + coord_store: Store, +) -> tuple[list[int], list[socket.socket]]: + """Bind 3 sockets and publish the ports to *coord_store*. + + Called by rank 0 only. Returns ``(ports, sockets)`` with the + sockets still open. + """ + socks: list[socket.socket] = [] + ports: list[int] = [] + for _ in range(3): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind((host, 0)) + s.listen() + socks.append(s) + ports.append(s.getsockname()[1]) + coord_store.set(key, struct.pack(_PORTS_FMT, *ports)) + return ports, socks + + +def _fetch_group_ports(key: str, coord_store: Store) -> list[int]: + """Read 3 ports published by rank 0 from *coord_store*. + + Blocks until the key is available. 
+ """ + return list(struct.unpack(_PORTS_FMT, coord_store.get(key))) + class StatelessGroupCoordinator(GroupCoordinator): """ @@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator): local_rank: int, torch_distributed_backend: str | Backend, use_device_communicator: bool, + coord_store: Store, use_message_queue_broadcaster: bool = False, group_name: str | None = None, host: str = "127.0.0.1", - group_ports: list[list[int]] | None = None, global_rank: int = 0, global_world_size: int = 1, ): @@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator): backend = str(torch_distributed_backend) self.backend = backend - assert group_ports is not None, "group_ports is not provided" for idx, ranks in enumerate(group_ranks): if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) self.rank_in_group = ranks.index(self.rank) - ports = group_ports[idx] - device_port = ports[0] - cpu_port = ports[1] - tcp_store_port = ports[2] + key = f"{group_name}_{idx}" + if self.rank_in_group == 0: + ports, socks = _allocate_group_ports( + key, + host, + coord_store, + ) + else: + ports = _fetch_group_ports(key, coord_store) + socks = [] + device_port, cpu_port, tcp_store_port = ports device_group = stateless_init_torch_distributed_process_group( host=host, @@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator): world_size=self.world_size, backend=backend, group_name=f"{self.unique_name}_device", + listen_socket=socks[0] if socks else None, ) cpu_group = stateless_init_torch_distributed_process_group( host=host, @@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator): world_size=self.world_size, backend="gloo", group_name=f"{self.unique_name}_cpu", + listen_socket=socks[1] if socks else None, ) tcp_store_group = StatelessProcessGroup.create( host=host, port=tcp_store_port, rank=self.rank_in_group, world_size=self.world_size, + listen_socket=socks[2] if socks else None, ) self_device_group = device_group diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 102f2f727..9991ab1dd 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -6,6 +6,7 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
import dataclasses +import functools import os import pickle import socket @@ -139,6 +140,29 @@ def get_pp_indices( return (start_layer, end_layer) +def create_tcp_store( + host: str, + port: int, + listen_socket: socket.socket | None = None, + **kwargs: Any, +) -> TCPStore: + """Create a TCPStore, optionally taking ownership of ``listen_socket``.""" + if listen_socket is None: + return TCPStore(host_name=host, port=port, **kwargs) + + listen_fd = listen_socket.detach() + try: + return TCPStore( + host_name=host, + port=port, + master_listen_fd=listen_fd, + **kwargs, + ) + except Exception: + socket.close(listen_fd) + raise + + @dataclasses.dataclass class StatelessProcessGroup: """A dataclass to hold a metadata store, and the rank, world_size of the @@ -150,9 +174,6 @@ class StatelessProcessGroup: world_size: int store: torch._C._distributed_c10d.Store - # stores a reference to the socket so that the file descriptor stays alive - socket: socket.socket | None - data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter @@ -419,6 +440,7 @@ class StatelessProcessGroup: world_size: int, data_expiration_seconds: int = 3600, store_timeout: int = 300, + listen_socket: socket.socket | None = None, ) -> "StatelessProcessGroup": """A replacement for `torch.distributed.init_process_group` that does not pollute the global state. @@ -436,36 +458,39 @@ class StatelessProcessGroup: C, and D can call `StatelessProcessGroup.create` to form another group. """ # noqa launch_server = rank == 0 - if launch_server: - # listen on the specified interface (instead of 0.0.0.0) + if launch_server and listen_socket is None: listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listen_socket.bind((host, port)) listen_socket.listen() - listen_fd = listen_socket.fileno() - else: - listen_socket = None - listen_fd = None - - store = TCPStore( - host_name=host, - port=port, + store = create_tcp_store( + host, + port, + listen_socket=listen_socket, world_size=world_size, is_master=launch_server, timeout=timedelta(seconds=store_timeout), use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 - master_listen_fd=listen_fd, ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, - socket=listen_socket, data_expiration_seconds=data_expiration_seconds, ) +@functools.lru_cache(maxsize=1) +def get_cached_tcp_store_client(host: str, port: int) -> TCPStore: + """Return a cached TCPStore client. + + Cached so that every call with the same ``(host, port)`` reuses the + same connection. A new ``(host, port)`` evicts the old entry. + """ + return TCPStore(host, port, is_master=False, wait_for_workers=False) + + def init_gloo_process_group( prefix_store: PrefixStore, group_rank: int, @@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group( backend: str, group_name: str | None = None, return_store: bool = False, + listen_socket: socket.socket | None = None, ) -> ProcessGroup | tuple[ProcessGroup, Store]: """ A replacement for `torch.distributed.init_process_group` that does not @@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group( are the same as process 1 and 5, the main communication channel is always formed with process 1, 2, ..., 8, and the additional communication channel is formed with process 9 and 10. + + When *listen_socket* is provided, the rendezvous step + is skipped and a ``TCPStore`` server is created directly using the + pre-bound socket. 
This is useful for eliminating TOCTOU races + between port allocation and binding. """ init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string timeout = _get_default_timeout(backend) - store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout) - ) + if listen_socket is not None: + store = create_tcp_store( + host, + port, + listen_socket=listen_socket, + world_size=world_size, + is_master=True, + timeout=timeout, + multi_tenant=True, + ) + else: + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout) + ) store.set_timeout(timeout) group_rank = rank diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d76948bc2..114d45fc4 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct): new_data_parallel_master_ip: str new_data_parallel_master_port: int new_data_parallel_master_port_list: list[int] - new_stateless_world_group_port_list: list[list[int]] - new_stateless_dp_group_port_list: list[list[int]] - new_stateless_ep_group_port_list: list[list[int]] - new_stateless_eplb_group_port_list: list[list[int]] + coord_store_port: int class ReconfigureRankType(enum.IntEnum): diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2f2acdd37..7d962f740 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc): new_parallel_config._data_parallel_master_port_list = ( reconfig_request.new_data_parallel_master_port_list ) + new_parallel_config._coord_store_port = reconfig_request.coord_store_port is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size is_shutdown = ( diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 4596824ec..91664058d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -455,56 +455,6 @@ class ElasticScalingCache: pending_notifications: dict[EEPNotificationType, set[int]] -def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int): - """ - Allocate stateless group ports for elastic EP. 
- """ - from vllm.utils.network_utils import get_open_ports_list - - assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled" - world_size = parallel_config.world_size - new_world_size_across_dp = world_size * new_data_parallel_size - num_world_groups = 1 - num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size) - num_ep_groups = max( - 1, - new_world_size_across_dp - // (new_data_parallel_size * parallel_config.tensor_parallel_size), - ) - num_eplb_groups = num_ep_groups - total_ports_needed = ( - num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups - ) * 3 + 5 - all_ports = get_open_ports_list(total_ports_needed) - new_data_parallel_master_port_list = all_ports[-5:] - all_ports = all_ports[:-5] - new_stateless_world_group_port_list = [ - all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) - ] - start_idx = num_world_groups * 3 - new_stateless_dp_group_port_list = [ - all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3) - ] - start_idx += num_dp_groups * 3 - new_stateless_ep_group_port_list = [ - all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3) - ] - start_idx += num_ep_groups * 3 - new_stateless_eplb_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) - ] - - parallel_config._stateless_world_group_port_list = ( - new_stateless_world_group_port_list - ) - parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list - parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list - parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list - parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop() - parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list - - class MPClient(EngineCoreClient): """ MPClient: base client for multi-proc EngineCore. 
@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient): self._ensure_output_queue_task() await future + def _setup_elastic_ep_reconfig_bootstrap(self) -> tuple[str, int]: + from vllm.distributed.utils import create_tcp_store + from vllm.utils.network_utils import get_open_ports_list + + parallel_config = self.vllm_config.parallel_config + parallel_config._data_parallel_master_port_list = get_open_ports_list(5) + parallel_config.data_parallel_master_port = ( + parallel_config._data_parallel_master_port_list.pop() + ) + + ip = parallel_config.data_parallel_master_ip + store = create_tcp_store( + ip, + 0, + is_master=True, + world_size=-1, + wait_for_workers=False, + ) + parallel_config._coord_store_port = store.port + self._coord_store = store + return ip, store.port + async def _scale_up_elastic_ep( self, cur_data_parallel_size: int, new_data_parallel_size: int ) -> None: @@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ) parallel_config = self.vllm_config.parallel_config - allocate_stateless_group_ports(parallel_config, new_data_parallel_size) + ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap() # Phase 1: Send reconfig messages to existing engines reconfig_futures = [] @@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_ip=ip, new_data_parallel_master_port=parallel_config.data_parallel_master_port, new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, - new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, - new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, - new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list, - new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list, + coord_store_port=coord_store_port, ) coro = self._call_utility_async( "reinitialize_distributed", reconfig_request, engine=engine @@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ) parallel_config = self.vllm_config.parallel_config - allocate_stateless_group_ports(parallel_config, new_data_parallel_size) + ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap() reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): @@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient): new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_ip=ip, new_data_parallel_master_port=parallel_config.data_parallel_master_port, new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, - new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, - new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, - new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list, - new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list, + coord_store_port=coord_store_port, ) if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = ( diff --git a/vllm/v1/engine/utils.py 
b/vllm/v1/engine/utils.py index fb1c45946..52c721734 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -301,7 +301,20 @@ class CoreEngineActorManager: else: ray.init() - vllm_config.parallel_config.allocate_elastic_ep_ports() + parallel_config = vllm_config.parallel_config + if parallel_config.enable_elastic_ep: + from vllm.distributed.utils import create_tcp_store + + ip = parallel_config.data_parallel_master_ip + store = create_tcp_store( + ip, + 0, + is_master=True, + world_size=-1, + wait_for_workers=False, + ) + parallel_config._coord_store_port = store.port + self._coord_store = store if placement_groups is not None: assert local_dp_ranks is not None, (
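
Note: the bootstrap added in DPLBAsyncMPClient and CoreEngineActorManager relies on create_tcp_store binding the listening socket before the port is advertised, so the TCPStore server adopts an already-bound file descriptor. A minimal sketch of that handoff follows; the helper name make_coord_store is illustrative, and it assumes torch.distributed's TCPStore as used elsewhere in the diff.

import socket
from datetime import timedelta

from torch.distributed import TCPStore


def make_coord_store(host: str) -> tuple[TCPStore, int]:
    # Bind and listen before the port is ever advertised, so there is no gap
    # between "pick a free port" and "start the store server" for another
    # process to slip into (the TOCTOU race the diff is eliminating).
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind((host, 0))
    listener.listen()
    port = listener.getsockname()[1]
    # detach() hands ownership of the file descriptor to the TCPStore server,
    # mirroring create_tcp_store(..., master_listen_fd=...).
    store = TCPStore(
        host_name=host,
        port=port,
        is_master=True,
        wait_for_workers=False,
        timeout=timedelta(seconds=300),
        master_listen_fd=listener.detach(),
    )
    return store, port

After this, only the port needs to be shipped to the engines (the single coord_store_port field of ReconfigureDistributedRequest); the per-group NCCL/Gloo/TCPStore ports are exchanged through the store at runtime rather than pre-allocated in lists.
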