elastic_ep: Fix stateless group port races (#36330)
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
This commit is contained in:
@@ -24,8 +24,7 @@ steps:
|
||||
|
||||
- label: Elastic EP Scaling Test
|
||||
timeout_in_minutes: 20
|
||||
device: b200
|
||||
optional: true
|
||||
device: h100
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_devices: 4
|
||||
source_file_dependencies:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import socket
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
@@ -266,33 +267,9 @@ class ParallelConfig:
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
"""
|
||||
|
||||
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
It is a list of list[int], with each inner list contains a set of 3 ports
|
||||
to be used for setting up the stateless CPU/device/TCPStore groups
|
||||
in StatelessGroupCoordinator. The number of inner lists is equal to
|
||||
the number of DP groups,
|
||||
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
|
||||
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
|
||||
"""
|
||||
|
||||
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
|
||||
"""
|
||||
|
||||
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
|
||||
Same topology as EP but separate NCCL communicator to avoid deadlocks.
|
||||
"""
|
||||
|
||||
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless world group when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
len(self._stateless_world_group_port_list) == 1,
|
||||
"""
|
||||
_coord_store_port: int = 0
|
||||
"""Port of the coordination TCPStore. Can be set by the API server; workers
|
||||
connect as clients to exchange self-picked group ports at runtime."""
|
||||
|
||||
decode_context_parallel_size: int = 1
|
||||
"""Number of decode context parallel groups, because the world size does
|
||||
@@ -465,65 +442,32 @@ class ParallelConfig:
|
||||
|
||||
return answer
|
||||
|
||||
def allocate_elastic_ep_ports(self) -> None:
|
||||
"""Allocate all ports for elastic EP (stateless groups + DP master).
|
||||
def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]:
|
||||
"""Return ``(port, listen_socket)`` for DP group init.
|
||||
|
||||
Must be called AFTER ray.init() so that ports claimed by Ray's
|
||||
idle worker pool are already in use and won't be returned by
|
||||
get_open_ports_list().
|
||||
With a coord store, rank 0 binds a socket and publishes the port;
|
||||
others read it. Without one, pops a pre-allocated port and
|
||||
returns ``listen_socket=None``.
|
||||
"""
|
||||
if not self.enable_elastic_ep:
|
||||
return
|
||||
if self._stateless_world_group_port_list:
|
||||
return
|
||||
if not self._coord_store_port:
|
||||
return self.get_next_dp_init_port(), None
|
||||
|
||||
num_world_groups = 1
|
||||
dp_size = self.data_parallel_size
|
||||
ep_size = self.data_parallel_size * self.world_size_across_dp
|
||||
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
|
||||
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
|
||||
num_eplb_groups = num_ep_groups
|
||||
total_stateless_ports = (
|
||||
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
|
||||
) * 3
|
||||
num_dp_master_ports = 5
|
||||
from vllm.distributed.utils import get_cached_tcp_store_client
|
||||
|
||||
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
|
||||
store = get_cached_tcp_store_client(
|
||||
self.data_parallel_master_ip, self._coord_store_port
|
||||
)
|
||||
|
||||
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
|
||||
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
|
||||
all_ports = all_ports[:-num_dp_master_ports]
|
||||
|
||||
self._stateless_world_group_port_list = [
|
||||
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
|
||||
]
|
||||
start_idx = num_world_groups * 3
|
||||
self._stateless_dp_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_dp_groups * 3
|
||||
self._stateless_ep_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_ep_groups * 3
|
||||
self._stateless_eplb_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
|
||||
]
|
||||
|
||||
def get_next_stateless_world_group_port(self) -> list[int]:
|
||||
return self._stateless_world_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_dp_group_port(self) -> list[int]:
|
||||
return self._stateless_dp_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_ep_group_port(self) -> list[int]:
|
||||
return self._stateless_ep_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_eplb_group_port(self) -> list[int]:
|
||||
return self._stateless_eplb_group_port_list.pop()
|
||||
key = "dp_master_port"
|
||||
if self.data_parallel_rank == 0:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind((self.data_parallel_master_ip, 0))
|
||||
s.listen()
|
||||
port = s.getsockname()[1]
|
||||
store.set(key, str(port).encode())
|
||||
return port, s
|
||||
else:
|
||||
return int(store.get(key).decode()), None
|
||||
|
||||
@overload
|
||||
def stateless_init_dp_group(
|
||||
@@ -553,14 +497,16 @@ class ParallelConfig:
|
||||
last_exc: Exception | None = None
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
port, listen_socket = self._pick_stateless_dp_port()
|
||||
# use gloo since the engine process might not have cuda device
|
||||
return stateless_init_torch_distributed_process_group(
|
||||
self.data_parallel_master_ip,
|
||||
self.get_next_dp_init_port(),
|
||||
port,
|
||||
self.data_parallel_rank,
|
||||
self.data_parallel_size,
|
||||
backend="gloo",
|
||||
return_store=return_store,
|
||||
listen_socket=listen_socket,
|
||||
)
|
||||
except DistNetworkError as e:
|
||||
# We only want to retry when the root cause is EADDRINUSE.
|
||||
|
||||
@@ -162,10 +162,8 @@ class ElasticEPScalingExecutor:
|
||||
new_dp_size=new_dp_size,
|
||||
new_world_size_across_dp=new_world_size_across_dp,
|
||||
master_ip=reconfig_request.new_data_parallel_master_ip,
|
||||
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
|
||||
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
|
||||
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
|
||||
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
|
||||
coord_store_port=reconfig_request.coord_store_port,
|
||||
enable_eplb=updated_config.parallel_config.enable_eplb,
|
||||
)
|
||||
self.worker.model_runner.eep_eplb_suppressed = True
|
||||
standby_ep_group = get_standby_ep_group()
|
||||
|
||||
@@ -563,15 +563,4 @@ class ElasticEPScalingState:
|
||||
parallel_config._data_parallel_master_port_list = (
|
||||
reconfig_request.new_data_parallel_master_port_list
|
||||
)
|
||||
parallel_config._stateless_world_group_port_list = (
|
||||
reconfig_request.new_stateless_world_group_port_list
|
||||
)
|
||||
parallel_config._stateless_dp_group_port_list = (
|
||||
reconfig_request.new_stateless_dp_group_port_list
|
||||
)
|
||||
parallel_config._stateless_ep_group_port_list = (
|
||||
reconfig_request.new_stateless_ep_group_port_list
|
||||
)
|
||||
parallel_config._stateless_eplb_group_port_list = (
|
||||
reconfig_request.new_stateless_eplb_group_port_list
|
||||
)
|
||||
parallel_config._coord_store_port = reconfig_request.coord_store_port
|
||||
|
||||
@@ -38,10 +38,8 @@ def create_standby_groups(
|
||||
new_dp_size: int,
|
||||
new_world_size_across_dp: int,
|
||||
master_ip: str,
|
||||
world_group_ports: list[list[int]],
|
||||
dp_group_ports: list[list[int]],
|
||||
ep_group_ports: list[list[int]],
|
||||
eplb_group_ports: list[list[int]] | None = None,
|
||||
coord_store_port: int,
|
||||
enable_eplb: bool = True,
|
||||
backend: str | None = None,
|
||||
) -> None:
|
||||
global \
|
||||
@@ -51,19 +49,23 @@ def create_standby_groups(
|
||||
_STANDBY_EP, \
|
||||
_STANDBY_EPLB
|
||||
|
||||
from vllm.distributed.utils import get_cached_tcp_store_client
|
||||
|
||||
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
|
||||
world_group = get_world_group()
|
||||
assert isinstance(world_group, StatelessGroupCoordinator)
|
||||
backend = backend or world_group.backend
|
||||
|
||||
coord_store = get_cached_tcp_store_client(master_ip, coord_store_port)
|
||||
|
||||
standby_world_ranks = [list(range(new_world_size_across_dp))]
|
||||
_STANDBY_WORLD = _init_stateless_group(
|
||||
standby_world_ranks,
|
||||
"world",
|
||||
world_group_ports,
|
||||
master_ip,
|
||||
backend,
|
||||
use_device_communicator=False,
|
||||
coord_store=coord_store,
|
||||
)
|
||||
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
|
||||
|
||||
@@ -76,7 +78,7 @@ def create_standby_groups(
|
||||
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
|
||||
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
|
||||
_STANDBY_DP = _init_stateless_group(
|
||||
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
|
||||
standby_dp_ranks, "dp", master_ip, backend, coord_store=coord_store
|
||||
)
|
||||
|
||||
standby_ep_ranks = (
|
||||
@@ -84,12 +86,16 @@ def create_standby_groups(
|
||||
)
|
||||
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
|
||||
_STANDBY_EP = _init_stateless_group(
|
||||
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
|
||||
standby_ep_ranks, "ep", master_ip, backend, coord_store=coord_store
|
||||
)
|
||||
|
||||
if eplb_group_ports is not None:
|
||||
if enable_eplb:
|
||||
_STANDBY_EPLB = _init_stateless_group(
|
||||
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
|
||||
standby_ep_ranks,
|
||||
"eplb",
|
||||
master_ip,
|
||||
backend,
|
||||
coord_store=coord_store,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -40,13 +40,16 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
import torch.distributed._symmetric_memory
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
from torch.distributed import Backend, ProcessGroup, Store
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase,
|
||||
)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.distributed.utils import (
|
||||
StatelessProcessGroup,
|
||||
get_cached_tcp_store_client,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.network_utils import get_distributed_init_method
|
||||
@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
|
||||
def _init_stateless_group(
|
||||
group_ranks: list[list[int]],
|
||||
group_name: str,
|
||||
group_ports: list[list[int]],
|
||||
host: str,
|
||||
backend: str,
|
||||
coord_store: Store,
|
||||
use_device_communicator: bool = True,
|
||||
) -> "StatelessGroupCoordinator":
|
||||
"""Create a StatelessGroupCoordinator with the given parameters."""
|
||||
@@ -1180,7 +1183,7 @@ def _init_stateless_group(
|
||||
use_device_communicator=use_device_communicator,
|
||||
group_name=group_name,
|
||||
host=host,
|
||||
group_ports=group_ports,
|
||||
coord_store=coord_store,
|
||||
global_rank=world.rank,
|
||||
global_world_size=world.world_size,
|
||||
)
|
||||
@@ -1321,7 +1324,9 @@ def _init_elastic_ep_world(
|
||||
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
|
||||
if global_rank in all_ranks:
|
||||
group_ranks = [all_ranks]
|
||||
group_ports = [parallel_config.get_next_stateless_world_group_port()]
|
||||
coord_store = get_cached_tcp_store_client(
|
||||
parallel_config.data_parallel_master_ip, parallel_config._coord_store_port
|
||||
)
|
||||
world = StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
@@ -1329,7 +1334,7 @@ def _init_elastic_ep_world(
|
||||
use_device_communicator=False,
|
||||
group_name="world",
|
||||
host=parallel_config.data_parallel_master_ip,
|
||||
group_ports=group_ports,
|
||||
coord_store=coord_store,
|
||||
global_rank=global_rank,
|
||||
global_world_size=global_world_size,
|
||||
)
|
||||
@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
|
||||
config = get_current_vllm_config()
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
enable_elastic_ep = config.parallel_config.enable_elastic_ep
|
||||
parallel_config = config.parallel_config
|
||||
coord_store: Store | None = None
|
||||
if enable_elastic_ep:
|
||||
coord_store = get_cached_tcp_store_client(
|
||||
parallel_config.data_parallel_master_ip,
|
||||
parallel_config._coord_store_port,
|
||||
)
|
||||
# Use stateless world group for global information
|
||||
world_size = get_world_group().world_size
|
||||
rank = get_world_group().rank
|
||||
@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
|
||||
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
dp_ports = [
|
||||
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
|
||||
]
|
||||
_DP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"dp",
|
||||
dp_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
coord_store=coord_store,
|
||||
)
|
||||
else:
|
||||
_DP = init_model_parallel_group(
|
||||
@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
ep_ports = [
|
||||
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
|
||||
]
|
||||
_EP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"ep",
|
||||
ep_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
coord_store=coord_store,
|
||||
)
|
||||
else:
|
||||
_EP = init_model_parallel_group(
|
||||
@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
|
||||
and config.parallel_config.enable_eplb
|
||||
):
|
||||
if enable_elastic_ep:
|
||||
eplb_ports = [
|
||||
parallel_config.get_next_stateless_eplb_group_port()
|
||||
for _ in group_ranks
|
||||
]
|
||||
_EPLB = _init_stateless_group(
|
||||
group_ranks,
|
||||
"eplb",
|
||||
eplb_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
coord_store=coord_store,
|
||||
)
|
||||
else:
|
||||
_EPLB = init_model_parallel_group(
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import socket
|
||||
import struct
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
from torch.distributed import Backend, ProcessGroup, Store
|
||||
|
||||
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -23,6 +25,38 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Wire format used to exchange a group's ports through the coordination
# store: three unsigned 32-bit integers in network byte order.
_PORTS_FMT = "!3I"


def _allocate_group_ports(
    key: str,
    host: str,
    coord_store: Store,
) -> tuple[list[int], list[socket.socket]]:
    """Bind 3 listening sockets and publish their ports to *coord_store*.

    Called by rank 0 only. Each socket is bound to ``(host, 0)``; the
    port numbers read back from the sockets are packed with
    ``_PORTS_FMT`` and stored under *key*.

    Args:
        key: Store key under which the packed ports are published.
        host: Local address the sockets are bound to.
        coord_store: Store that receives the packed ports via ``set``.

    Returns:
        ``(ports, sockets)`` with the sockets still open and listening;
        the caller is responsible for them from then on.

    Raises:
        Exception: Whatever ``bind``/``listen``/``coord_store.set``
            raised. Every socket opened by this call is closed before
            the exception propagates.
    """
    socks: list[socket.socket] = []
    try:
        for _ in range(3):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track the socket before bind()/listen() so that it is
            # released below even if one of those calls fails.
            socks.append(s)
            s.bind((host, 0))
            s.listen()
        ports = [s.getsockname()[1] for s in socks]
        coord_store.set(key, struct.pack(_PORTS_FMT, *ports))
    except Exception:
        # Do not rely on garbage collection to free the ports: a caller
        # that keeps the exception around would otherwise keep the
        # half-initialised listening sockets alive with it.
        for s in socks:
            s.close()
        raise
    return ports, socks


def _fetch_group_ports(key: str, coord_store: Store) -> list[int]:
    """Read the 3 ports published under *key* in *coord_store*.

    Counterpart of ``_allocate_group_ports``: the value is decoded with
    the same ``_PORTS_FMT``.

    NOTE(review): presumably ``coord_store.get`` blocks until the key is
    available (true for a torch ``TCPStore``, up to its timeout) --
    confirm for other store types.

    Args:
        key: Store key the ports were published under.
        coord_store: Store to read from.

    Returns:
        The three port numbers, in the order they were published.
    """
    return list(struct.unpack(_PORTS_FMT, coord_store.get(key)))
|
||||
|
||||
|
||||
class StatelessGroupCoordinator(GroupCoordinator):
|
||||
"""
|
||||
@@ -39,10 +73,10 @@ class StatelessGroupCoordinator(GroupCoordinator):
|
||||
local_rank: int,
|
||||
torch_distributed_backend: str | Backend,
|
||||
use_device_communicator: bool,
|
||||
coord_store: Store,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: str | None = None,
|
||||
host: str = "127.0.0.1",
|
||||
group_ports: list[list[int]] | None = None,
|
||||
global_rank: int = 0,
|
||||
global_world_size: int = 1,
|
||||
):
|
||||
@@ -61,17 +95,23 @@ class StatelessGroupCoordinator(GroupCoordinator):
|
||||
|
||||
backend = str(torch_distributed_backend)
|
||||
self.backend = backend
|
||||
assert group_ports is not None, "group_ports is not provided"
|
||||
for idx, ranks in enumerate(group_ranks):
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
|
||||
ports = group_ports[idx]
|
||||
device_port = ports[0]
|
||||
cpu_port = ports[1]
|
||||
tcp_store_port = ports[2]
|
||||
key = f"{group_name}_{idx}"
|
||||
if self.rank_in_group == 0:
|
||||
ports, socks = _allocate_group_ports(
|
||||
key,
|
||||
host,
|
||||
coord_store,
|
||||
)
|
||||
else:
|
||||
ports = _fetch_group_ports(key, coord_store)
|
||||
socks = []
|
||||
device_port, cpu_port, tcp_store_port = ports
|
||||
|
||||
device_group = stateless_init_torch_distributed_process_group(
|
||||
host=host,
|
||||
@@ -80,6 +120,7 @@ class StatelessGroupCoordinator(GroupCoordinator):
|
||||
world_size=self.world_size,
|
||||
backend=backend,
|
||||
group_name=f"{self.unique_name}_device",
|
||||
listen_socket=socks[0] if socks else None,
|
||||
)
|
||||
cpu_group = stateless_init_torch_distributed_process_group(
|
||||
host=host,
|
||||
@@ -88,12 +129,14 @@ class StatelessGroupCoordinator(GroupCoordinator):
|
||||
world_size=self.world_size,
|
||||
backend="gloo",
|
||||
group_name=f"{self.unique_name}_cpu",
|
||||
listen_socket=socks[1] if socks else None,
|
||||
)
|
||||
tcp_store_group = StatelessProcessGroup.create(
|
||||
host=host,
|
||||
port=tcp_store_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
listen_socket=socks[2] if socks else None,
|
||||
)
|
||||
|
||||
self_device_group = device_group
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
import dataclasses
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
@@ -139,6 +140,29 @@ def get_pp_indices(
|
||||
return (start_layer, end_layer)
|
||||
|
||||
|
||||
def create_tcp_store(
    host: str,
    port: int,
    listen_socket: socket.socket | None = None,
    **kwargs: Any,
) -> TCPStore:
    """Build a ``TCPStore``, handing it *listen_socket* when one is given.

    Args:
        host: Passed to ``TCPStore`` as ``host_name``.
        port: Passed to ``TCPStore`` as ``port``.
        listen_socket: Optional socket whose file descriptor is passed to
            the store as ``master_listen_fd``. The descriptor is detached
            from the Python socket object, so the caller must not use
            *listen_socket* afterwards.
        **kwargs: Forwarded unchanged to ``TCPStore``.

    Returns:
        The constructed ``TCPStore``.

    Raises:
        Exception: Re-raises whatever ``TCPStore`` raised; in the
            *listen_socket* case the detached descriptor is closed first.
    """
    if listen_socket is not None:
        # detach() hands back the raw descriptor without closing it; from
        # here on this function owns it.
        fd = listen_socket.detach()
        try:
            store = TCPStore(
                host_name=host,
                port=port,
                master_listen_fd=fd,
                **kwargs,
            )
        except Exception:
            # Construction failed, so close the descriptor here.
            socket.close(fd)
            raise
        return store

    return TCPStore(host_name=host, port=port, **kwargs)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StatelessProcessGroup:
|
||||
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||
@@ -150,9 +174,6 @@ class StatelessProcessGroup:
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
|
||||
# stores a reference to the socket so that the file descriptor stays alive
|
||||
socket: socket.socket | None
|
||||
|
||||
data_expiration_seconds: int = 3600 # 1 hour
|
||||
|
||||
# dst rank -> counter
|
||||
@@ -419,6 +440,7 @@ class StatelessProcessGroup:
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
store_timeout: int = 300,
|
||||
listen_socket: socket.socket | None = None,
|
||||
) -> "StatelessProcessGroup":
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
@@ -436,36 +458,39 @@ class StatelessProcessGroup:
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
launch_server = rank == 0
|
||||
if launch_server:
|
||||
# listen on the specified interface (instead of 0.0.0.0)
|
||||
if launch_server and listen_socket is None:
|
||||
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
listen_socket.bind((host, port))
|
||||
listen_socket.listen()
|
||||
listen_fd = listen_socket.fileno()
|
||||
else:
|
||||
listen_socket = None
|
||||
listen_fd = None
|
||||
|
||||
store = TCPStore(
|
||||
host_name=host,
|
||||
port=port,
|
||||
store = create_tcp_store(
|
||||
host,
|
||||
port,
|
||||
listen_socket=listen_socket,
|
||||
world_size=world_size,
|
||||
is_master=launch_server,
|
||||
timeout=timedelta(seconds=store_timeout),
|
||||
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
|
||||
master_listen_fd=listen_fd,
|
||||
)
|
||||
|
||||
return StatelessProcessGroup(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
socket=listen_socket,
|
||||
data_expiration_seconds=data_expiration_seconds,
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
def get_cached_tcp_store_client(host: str, port: int) -> TCPStore:
    """Return a client-side ``TCPStore`` for ``(host, port)``.

    The result is memoised with a one-entry LRU cache: repeated calls
    with the same ``(host, port)`` return the same object, and a call
    with a different pair replaces the cached entry.
    """
    client = TCPStore(
        host_name=host,
        port=port,
        is_master=False,
        wait_for_workers=False,
    )
    return client
|
||||
|
||||
|
||||
def init_gloo_process_group(
|
||||
prefix_store: PrefixStore,
|
||||
group_rank: int,
|
||||
@@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group(
|
||||
backend: str,
|
||||
group_name: str | None = None,
|
||||
return_store: bool = False,
|
||||
listen_socket: socket.socket | None = None,
|
||||
) -> ProcessGroup | tuple[ProcessGroup, Store]:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
@@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group(
|
||||
are the same as process 1 and 5, the main communication channel is
|
||||
always formed with process 1, 2, ..., 8, and the additional communication
|
||||
channel is formed with process 9 and 10.
|
||||
|
||||
When *listen_socket* is provided, the rendezvous step
|
||||
is skipped and a ``TCPStore`` server is created directly using the
|
||||
pre-bound socket. This is useful for eliminating TOCTOU races
|
||||
between port allocation and binding.
|
||||
"""
|
||||
init_method = get_tcp_uri(host, port)
|
||||
backend = Backend(backend) # it is basically string
|
||||
timeout = _get_default_timeout(backend)
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout)
|
||||
)
|
||||
if listen_socket is not None:
|
||||
store = create_tcp_store(
|
||||
host,
|
||||
port,
|
||||
listen_socket=listen_socket,
|
||||
world_size=world_size,
|
||||
is_master=True,
|
||||
timeout=timeout,
|
||||
multi_tenant=True,
|
||||
)
|
||||
else:
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout)
|
||||
)
|
||||
store.set_timeout(timeout)
|
||||
|
||||
group_rank = rank
|
||||
|
||||
@@ -237,10 +237,7 @@ class ReconfigureDistributedRequest(msgspec.Struct):
|
||||
new_data_parallel_master_ip: str
|
||||
new_data_parallel_master_port: int
|
||||
new_data_parallel_master_port_list: list[int]
|
||||
new_stateless_world_group_port_list: list[list[int]]
|
||||
new_stateless_dp_group_port_list: list[list[int]]
|
||||
new_stateless_ep_group_port_list: list[list[int]]
|
||||
new_stateless_eplb_group_port_list: list[list[int]]
|
||||
coord_store_port: int
|
||||
|
||||
|
||||
class ReconfigureRankType(enum.IntEnum):
|
||||
|
||||
@@ -1767,6 +1767,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
new_parallel_config._data_parallel_master_port_list = (
|
||||
reconfig_request.new_data_parallel_master_port_list
|
||||
)
|
||||
new_parallel_config._coord_store_port = reconfig_request.coord_store_port
|
||||
|
||||
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
|
||||
is_shutdown = (
|
||||
|
||||
@@ -455,56 +455,6 @@ class ElasticScalingCache:
|
||||
pending_notifications: dict[EEPNotificationType, set[int]]
|
||||
|
||||
|
||||
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
    """Reserve ports for the elastic-EP stateless groups and the DP master.

    One batch of open ports is requested and carved up, in this order:
    world groups, DP groups, EP groups, EPLB groups (3 ports per group),
    followed by 5 DP-master ports taken from the tail of the batch. The
    results are written onto *parallel_config*; nothing is returned.

    Args:
        parallel_config: Config object to update. Must have
            ``enable_elastic_ep`` set; ``world_size`` and
            ``tensor_parallel_size`` are read from it.
        new_data_parallel_size: Data-parallel size the groups are sized for.
    """
    from vllm.utils.network_utils import get_open_ports_list

    assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"

    total_ranks = parallel_config.world_size * new_data_parallel_size

    # Number of groups of each kind, listed in the order their ports are
    # laid out in the batch.
    world_count = 1
    dp_count = max(1, total_ranks // new_data_parallel_size)
    ep_count = max(
        1,
        total_ranks
        // (new_data_parallel_size * parallel_config.tensor_parallel_size),
    )
    eplb_count = ep_count  # EPLB uses the same group count as EP
    group_counts = (world_count, dp_count, ep_count, eplb_count)

    pool = get_open_ports_list(sum(group_counts) * 3 + 5)
    master_ports = pool[-5:]
    pool = pool[:-5]

    # Walk the remaining pool once, cutting it into triples per group kind.
    port_lists = []
    cursor = 0
    for count in group_counts:
        stop = cursor + count * 3
        port_lists.append([pool[i : i + 3] for i in range(cursor, stop, 3)])
        cursor = stop
    world_ports, dp_ports, ep_ports, eplb_ports = port_lists

    parallel_config._stateless_world_group_port_list = world_ports
    parallel_config._stateless_dp_group_port_list = dp_ports
    parallel_config._stateless_ep_group_port_list = ep_ports
    parallel_config._stateless_eplb_group_port_list = eplb_ports
    # The last of the 5 master ports becomes the active one; the other
    # four stay in the list.
    parallel_config.data_parallel_master_port = master_ports.pop()
    parallel_config._data_parallel_master_port_list = master_ports
|
||||
|
||||
|
||||
class MPClient(EngineCoreClient):
|
||||
"""
|
||||
MPClient: base client for multi-proc EngineCore.
|
||||
@@ -1541,6 +1491,28 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
self._ensure_output_queue_task()
|
||||
await future
|
||||
|
||||
def _setup_elastic_ep_reconfig_bootstrap(self) -> tuple[str, int]:
    """Prepare the parallel config for an elastic-EP reconfiguration.

    Refreshes the DP master ports on ``self.vllm_config.parallel_config``
    and starts a new coordination store on the DP master IP, recording
    its port in the config.

    Returns:
        ``(ip, port)`` of the newly created coordination store.
    """
    from vllm.distributed.utils import create_tcp_store
    from vllm.utils.network_utils import get_open_ports_list

    cfg = self.vllm_config.parallel_config

    # Five fresh ports: one is popped off as the active DP master port,
    # the other four stay in the list.
    cfg._data_parallel_master_port_list = get_open_ports_list(5)
    cfg.data_parallel_master_port = cfg._data_parallel_master_port_list.pop()

    master_ip = cfg.data_parallel_master_ip
    # Port 0 is requested and the actual port is read back from the
    # store afterwards.
    coord_store = create_tcp_store(
        master_ip,
        0,
        is_master=True,
        world_size=-1,
        wait_for_workers=False,
    )
    bound_port = coord_store.port
    cfg._coord_store_port = bound_port
    # Stored on the client; presumably so the server outlives this call.
    self._coord_store = coord_store
    return master_ip, bound_port
|
||||
|
||||
async def _scale_up_elastic_ep(
|
||||
self, cur_data_parallel_size: int, new_data_parallel_size: int
|
||||
) -> None:
|
||||
@@ -1555,7 +1527,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
|
||||
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
|
||||
|
||||
# Phase 1: Send reconfig messages to existing engines
|
||||
reconfig_futures = []
|
||||
@@ -1564,13 +1536,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
new_data_parallel_size=new_data_parallel_size,
|
||||
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
|
||||
new_data_parallel_master_ip=ip,
|
||||
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
|
||||
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
|
||||
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
|
||||
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
|
||||
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
|
||||
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
|
||||
coord_store_port=coord_store_port,
|
||||
)
|
||||
coro = self._call_utility_async(
|
||||
"reinitialize_distributed", reconfig_request, engine=engine
|
||||
@@ -1650,7 +1619,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
|
||||
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
|
||||
|
||||
reconfig_futures = []
|
||||
for cur_dp_rank, engine in enumerate(self.core_engines):
|
||||
@@ -1658,13 +1627,10 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
new_data_parallel_size=new_data_parallel_size,
|
||||
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
|
||||
new_data_parallel_master_ip=ip,
|
||||
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
|
||||
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
|
||||
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
|
||||
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
|
||||
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
|
||||
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
|
||||
coord_store_port=coord_store_port,
|
||||
)
|
||||
if cur_dp_rank >= new_data_parallel_size:
|
||||
reconfig_request.new_data_parallel_rank = (
|
||||
|
||||
@@ -301,7 +301,20 @@ class CoreEngineActorManager:
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
vllm_config.parallel_config.allocate_elastic_ep_ports()
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.enable_elastic_ep:
|
||||
from vllm.distributed.utils import create_tcp_store
|
||||
|
||||
ip = parallel_config.data_parallel_master_ip
|
||||
store = create_tcp_store(
|
||||
ip,
|
||||
0,
|
||||
is_master=True,
|
||||
world_size=-1,
|
||||
wait_for_workers=False,
|
||||
)
|
||||
parallel_config._coord_store_port = store.port
|
||||
self._coord_store = store
|
||||
|
||||
if placement_groups is not None:
|
||||
assert local_dp_ranks is not None, (
|
||||
|
||||
Reference in New Issue
Block a user