elastic_ep: Fix stateless group port races (#36330)

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
This commit is contained in:
Itay Alroy
2026-03-18 16:36:18 +02:00
committed by GitHub
parent 99267c23ca
commit de1a86b7de
12 changed files with 221 additions and 222 deletions

View File

@@ -24,8 +24,7 @@ steps:
- label: Elastic EP Scaling Test
timeout_in_minutes: 20
device: b200
optional: true
device: h100
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, overload
@@ -266,33 +267,9 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
"""
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1,
"""
_coord_store_port: int = 0
"""Port of the coordination TCPStore. Can be set by the API server; workers
connect as clients to exchange self-picked group ports at runtime."""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
@@ -465,65 +442,32 @@ class ParallelConfig:
return answer
def allocate_elastic_ep_ports(self) -> None:
"""Allocate all ports for elastic EP (stateless groups + DP master).
def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]:
"""Return ``(port, listen_socket)`` for DP group init.
Must be called AFTER ray.init() so that ports claimed by Ray's
idle worker pool are already in use and won't be returned by
get_open_ports_list().
With a coord store, rank 0 binds a socket and publishes the port;
others read it. Without one, pops a pre-allocated port and
returns ``listen_socket=None``.
"""
if not self.enable_elastic_ep:
return
if self._stateless_world_group_port_list:
return
if not self._coord_store_port:
return self.get_next_dp_init_port(), None
num_world_groups = 1
dp_size = self.data_parallel_size
ep_size = self.data_parallel_size * self.world_size_across_dp
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
num_eplb_groups = num_ep_groups
total_stateless_ports = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3
num_dp_master_ports = 5
from vllm.distributed.utils import get_cached_tcp_store_client
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
store = get_cached_tcp_store_client(
self.data_parallel_master_ip, self._coord_store_port
)
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
all_ports = all_ports[:-num_dp_master_ports]
self._stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
self._stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop()
def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop()
def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop()
def get_next_stateless_eplb_group_port(self) -> list[int]:
return self._stateless_eplb_group_port_list.pop()
key = "dp_master_port"
if self.data_parallel_rank == 0:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((self.data_parallel_master_ip, 0))
s.listen()
port = s.getsockname()[1]
store.set(key, str(port).encode())
return port, s
else:
return int(store.get(key).decode()), None
@overload
def stateless_init_dp_group(
@@ -553,14 +497,16 @@ class ParallelConfig:
last_exc: Exception | None = None
for _ in range(max_retries):
try:
port, listen_socket = self._pick_stateless_dp_port()
# use gloo since the engine process might not have cuda device
return stateless_init_torch_distributed_process_group(
self.data_parallel_master_ip,
self.get_next_dp_init_port(),
port,
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo",
return_store=return_store,
listen_socket=listen_socket,
)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.

View File

@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor:
new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
coord_store_port=reconfig_request.coord_store_port,
enable_eplb=updated_config.parallel_config.enable_eplb,
)
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()

View File

@@ -563,15 +563,4 @@ class ElasticEPScalingState:
parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
parallel_config._stateless_world_group_port_list = (
reconfig_request.new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = (
reconfig_request.new_stateless_dp_group_port_list
)
parallel_config._stateless_ep_group_port_list = (
reconfig_request.new_stateless_ep_group_port_list
)
parallel_config._stateless_eplb_group_port_list = (
reconfig_request.new_stateless_eplb_group_port_list
)
parallel_config._coord_store_port = reconfig_request.coord_store_port

View File

@@ -38,10 +38,8 @@ def create_standby_groups(
new_dp_size: int,
new_world_size_across_dp: int,
master_ip: str,
world_group_ports: list[list[int]],
dp_group_ports: list[list[int]],
ep_group_ports: list[list[int]],
eplb_group_ports: list[list[int]] | None = None,
coord_store_port: int,
enable_eplb: bool = True,
backend: str | None = None,
) -> None:
global \
@@ -51,19 +49,23 @@ def create_standby_groups(
_STANDBY_EP, \
_STANDBY_EPLB
from vllm.distributed.utils import get_cached_tcp_store_client
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
world_group = get_world_group()
assert isinstance(world_group, StatelessGroupCoordinator)
backend = backend or world_group.backend
coord_store = get_cached_tcp_store_client(master_ip, coord_store_port)
standby_world_ranks = [list(range(new_world_size_across_dp))]
_STANDBY_WORLD = _init_stateless_group(
standby_world_ranks,
"world",
world_group_ports,
master_ip,
backend,
use_device_communicator=False,
coord_store=coord_store,
)
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
@@ -76,7 +78,7 @@ def create_standby_groups(
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
_STANDBY_DP = _init_stateless_group(
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
standby_dp_ranks, "dp", master_ip, backend, coord_store=coord_store
)
standby_ep_ranks = (
@@ -84,12 +86,16 @@ def create_standby_groups(
)
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
_STANDBY_EP = _init_stateless_group(
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
standby_ep_ranks, "ep", master_ip, backend, coord_store=coord_store
)
if eplb_group_ports is not None:
if enable_eplb:
_STANDBY_EPLB = _init_stateless_group(
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
standby_ep_ranks,
"eplb",
master_ip,
backend,
coord_store=coord_store,
)

View File

@@ -40,13 +40,16 @@ import torch
import torch.distributed
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
from torch.distributed import Backend, ProcessGroup
from torch.distributed import Backend, ProcessGroup, Store
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.utils import (
StatelessProcessGroup,
get_cached_tcp_store_client,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.network_utils import get_distributed_init_method
@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
def _init_stateless_group(
group_ranks: list[list[int]],
group_name: str,
group_ports: list[list[int]],
host: str,
backend: str,
coord_store: Store,
use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
"""Create a StatelessGroupCoordinator with the given parameters."""
@@ -1180,7 +1183,7 @@ def _init_stateless_group(
use_device_communicator=use_device_communicator,
group_name=group_name,
host=host,
group_ports=group_ports,
coord_store=coord_store,
global_rank=world.rank,
global_world_size=world.world_size,
)
@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world(
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
if global_rank in all_ranks:
group_ranks = [all_ranks]
group_ports = [parallel_config.get_next_stateless_world_group_port()]
coord_store = get_cached_tcp_store_client(
parallel_config.data_parallel_master_ip, parallel_config._coord_store_port
)
world = StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world(
use_device_communicator=False,
group_name="world",
host=parallel_config.data_parallel_master_ip,
group_ports=group_ports,
coord_store=coord_store,
global_rank=global_rank,
global_world_size=global_world_size,
)
@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
parallel_config = config.parallel_config
coord_store: Store | None = None
if enable_elastic_ep:
coord_store = get_cached_tcp_store_client(
parallel_config.data_parallel_master_ip,
parallel_config._coord_store_port,
)
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
coord_store=coord_store,
)
else:
_DP = init_model_parallel_group(
@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
coord_store=coord_store,
)
else:
_EP = init_model_parallel_group(
@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
and config.parallel_config.enable_eplb
):
if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
coord_store=coord_store,
)
else:
_EPLB = init_model_parallel_group(

View File

@@ -1,9 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import socket
import struct
from typing import Any, Optional
import torch
from torch.distributed import Backend, ProcessGroup
from torch.distributed import Backend, ProcessGroup, Store
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
_PORTS_FMT = "!3I"
def _allocate_group_ports(
key: str,
host: str,
coord_store: Store,
) -> tuple[list[int], list[socket.socket]]:
"""Bind 3 sockets and publish the ports to *coord_store*.
Called by rank 0 only. Returns ``(ports, sockets)`` with the
sockets still open.
"""
socks: list[socket.socket] = []
ports: list[int] = []
for _ in range(3):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((host, 0))
s.listen()
socks.append(s)
ports.append(s.getsockname()[1])
coord_store.set(key, struct.pack(_PORTS_FMT, *ports))
return ports, socks
def _fetch_group_ports(key: str, coord_store: Store) -> list[int]:
"""Read 3 ports published by rank 0 from *coord_store*.
Blocks until the key is available.
"""
return list(struct.unpack(_PORTS_FMT, coord_store.get(key)))
class StatelessGroupCoordinator(GroupCoordinator):
"""
@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator):
local_rank: int,
torch_distributed_backend: str | Backend,
use_device_communicator: bool,
coord_store: Store,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
host: str = "127.0.0.1",
group_ports: list[list[int]] | None = None,
global_rank: int = 0,
global_world_size: int = 1,
):
@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator):
backend = str(torch_distributed_backend)
self.backend = backend
assert group_ports is not None, "group_ports is not provided"
for idx, ranks in enumerate(group_ranks):
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
ports = group_ports[idx]
device_port = ports[0]
cpu_port = ports[1]
tcp_store_port = ports[2]
key = f"{group_name}_{idx}"
if self.rank_in_group == 0:
ports, socks = _allocate_group_ports(
key,
host,
coord_store,
)
else:
ports = _fetch_group_ports(key, coord_store)
socks = []
device_port, cpu_port, tcp_store_port = ports
device_group = stateless_init_torch_distributed_process_group(
host=host,
@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size=self.world_size,
backend=backend,
group_name=f"{self.unique_name}_device",
listen_socket=socks[0] if socks else None,
)
cpu_group = stateless_init_torch_distributed_process_group(
host=host,
@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator):
world_size=self.world_size,
backend="gloo",
group_name=f"{self.unique_name}_cpu",
listen_socket=socks[1] if socks else None,
)
tcp_store_group = StatelessProcessGroup.create(
host=host,
port=tcp_store_port,
rank=self.rank_in_group,
world_size=self.world_size,
listen_socket=socks[2] if socks else None,
)
self_device_group = device_group

View File

@@ -6,6 +6,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import functools
import os
import pickle
import socket
@@ -139,6 +140,29 @@ def get_pp_indices(
return (start_layer, end_layer)
def create_tcp_store(
host: str,
port: int,
listen_socket: socket.socket | None = None,
**kwargs: Any,
) -> TCPStore:
"""Create a TCPStore, optionally taking ownership of ``listen_socket``."""
if listen_socket is None:
return TCPStore(host_name=host, port=port, **kwargs)
listen_fd = listen_socket.detach()
try:
return TCPStore(
host_name=host,
port=port,
master_listen_fd=listen_fd,
**kwargs,
)
except Exception:
socket.close(listen_fd)
raise
@dataclasses.dataclass
class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the
@@ -150,9 +174,6 @@ class StatelessProcessGroup:
world_size: int
store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: socket.socket | None
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
@@ -419,6 +440,7 @@ class StatelessProcessGroup:
world_size: int,
data_expiration_seconds: int = 3600,
store_timeout: int = 300,
listen_socket: socket.socket | None = None,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
@@ -436,36 +458,39 @@ class StatelessProcessGroup:
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
launch_server = rank == 0
if launch_server:
# listen on the specified interface (instead of 0.0.0.0)
if launch_server and listen_socket is None:
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port))
listen_socket.listen()
listen_fd = listen_socket.fileno()
else:
listen_socket = None
listen_fd = None
store = TCPStore(
host_name=host,
port=port,
store = create_tcp_store(
host,
port,
listen_socket=listen_socket,
world_size=world_size,
is_master=launch_server,
timeout=timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds,
)
@functools.lru_cache(maxsize=1)
def get_cached_tcp_store_client(host: str, port: int) -> TCPStore:
"""Return a cached TCPStore client.
Cached so that every call with the same ``(host, port)`` reuses the
same connection. A new ``(host, port)`` evicts the old entry.
"""
return TCPStore(host, port, is_master=False, wait_for_workers=False)
def init_gloo_process_group(
prefix_store: PrefixStore,
group_rank: int,
@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group(
backend: str,
group_name: str | None = None,
return_store: bool = False,
listen_socket: socket.socket | None = None,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
@@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group(
are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
When *listen_socket* is provided, the rendezvous step
is skipped and a ``TCPStore`` server is created directly using the
pre-bound socket. This is useful for eliminating TOCTOU races
between port allocation and binding.
"""
init_method = get_tcp_uri(host, port)
backend = Backend(backend) # it is basically string
timeout = _get_default_timeout(backend)
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout)
)
if listen_socket is not None:
store = create_tcp_store(
host,
port,
listen_socket=listen_socket,
world_size=world_size,
is_master=True,
timeout=timeout,
multi_tenant=True,
)
else:
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout)
)
store.set_timeout(timeout)
group_rank = rank

View File

@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_master_ip: str
new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
coord_store_port: int
class ReconfigureRankType(enum.IntEnum):

View File

@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc):
new_parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
new_parallel_config._coord_store_port = reconfig_request.coord_store_port
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
is_shutdown = (

View File

@@ -455,56 +455,6 @@ class ElasticScalingCache:
pending_notifications: dict[EEPNotificationType, set[int]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
"""
Allocate stateless group ports for elastic EP.
"""
from vllm.utils.network_utils import get_open_ports_list
assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
world_size = parallel_config.world_size
new_world_size_across_dp = world_size * new_data_parallel_size
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
num_ep_groups = max(
1,
new_world_size_across_dp
// (new_data_parallel_size * parallel_config.tensor_parallel_size),
)
num_eplb_groups = num_ep_groups
total_ports_needed = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3 + 5
all_ports = get_open_ports_list(total_ports_needed)
new_data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
new_stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
new_stateless_dp_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
new_stateless_ep_group_port_list = [
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
new_stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
parallel_config._stateless_world_group_port_list = (
new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
self._ensure_output_queue_task()
await future
def _setup_elastic_ep_reconfig_bootstrap(self) -> tuple[str, int]:
from vllm.distributed.utils import create_tcp_store
from vllm.utils.network_utils import get_open_ports_list
parallel_config = self.vllm_config.parallel_config
parallel_config._data_parallel_master_port_list = get_open_ports_list(5)
parallel_config.data_parallel_master_port = (
parallel_config._data_parallel_master_port_list.pop()
)
ip = parallel_config.data_parallel_master_ip
store = create_tcp_store(
ip,
0,
is_master=True,
world_size=-1,
wait_for_workers=False,
)
parallel_config._coord_store_port = store.port
self._coord_store = store
return ip, store.port
async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None:
@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
# Phase 1: Send reconfig messages to existing engines
reconfig_futures = []
@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_ip=ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
coord_store_port=coord_store_port,
)
coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine
@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_ip=ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
coord_store_port=coord_store_port,
)
if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = (

View File

@@ -301,7 +301,20 @@ class CoreEngineActorManager:
else:
ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
parallel_config = vllm_config.parallel_config
if parallel_config.enable_elastic_ep:
from vllm.distributed.utils import create_tcp_store
ip = parallel_config.data_parallel_master_ip
store = create_tcp_store(
ip,
0,
is_master=True,
world_size=-1,
wait_for_workers=False,
)
parallel_config._coord_store_port = store.port
self._coord_store = store
if placement_groups is not None:
assert local_dp_ranks is not None, (