[1/N] Elastic EP Milestone 2 (#34861)
Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
@@ -165,6 +165,9 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
+    enable_elastic_ep: bool = False
+    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
+
     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""
     ubatch_size: int = 0
@@ -244,6 +247,34 @@ class ParallelConfig:
     Set to be private as it's not intended to be configured by users.
     """
 
+    _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for stateless DP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    It is a list of list[int], where each inner list contains the set of 3 ports
+    used for setting up the stateless CPU/device/TCPStore groups
+    in StatelessGroupCoordinator. The number of inner lists equals
+    the number of DP groups, i.e.,
+    len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
+    and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
+    """
+
+    _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for stateless EP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size.
+    """
+
+    _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for stateless EPLB groups when enable_elastic_ep is True.
+    Same topology as EP, but a separate NCCL communicator to avoid deadlocks.
+    """
+
+    _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
+    """List of open ports for the stateless world group when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    len(self._stateless_world_group_port_list) == 1.
+    """
+
     decode_context_parallel_size: int = 1
     """Number of decode context parallel groups, because the world size does
     not change by dcp, it simply reuses the GPUs of TP group, and tp_size
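To make the documented topology concrete, here is a minimal standalone sketch (illustrative only: the port numbers are made up, and the shapes follow the docstrings above) of what `_stateless_dp_group_port_list` is expected to look like:

# Minimal sketch of the invariants documented above; not code from this PR.
# Assume TP=2 and DP=4, so world_size_across_dp == 8 and there are
# world_size_across_dp // dp_size == 2 DP groups (one per TP rank).
world_size_across_dp, dp_size = 8, 4
num_dp_groups = world_size_across_dp // dp_size

# Each inner list carries exactly 3 ports: one each for the stateless
# CPU group, device group, and TCPStore of one StatelessGroupCoordinator.
stateless_dp_group_port_list = [
    [15000, 15001, 15002],  # DP group 0
    [15003, 15004, 15005],  # DP group 1
]

assert len(stateless_dp_group_port_list) == num_dp_groups
assert all(len(ports) == 3 for ports in stateless_dp_group_port_list)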
@@ -402,7 +433,67 @@ class ParallelConfig:
         return answer
 
-    def stateless_init_dp_group(self) -> ProcessGroup:
+    def allocate_elastic_ep_ports(self) -> None:
+        """Allocate all ports for elastic EP (stateless groups + DP master).
+
+        Must be called AFTER ray.init() so that ports claimed by Ray's
+        idle worker pool are already in use and won't be returned by
+        get_open_ports_list().
+        """
+        if not self.enable_elastic_ep:
+            return
+        if self._stateless_world_group_port_list:
+            return
+
+        num_world_groups = 1
+        dp_size = self.data_parallel_size
+        ep_size = self.data_parallel_size * self.world_size_across_dp
+        num_dp_groups = max(1, self.world_size_across_dp // dp_size)
+        num_ep_groups = max(1, self.world_size_across_dp // ep_size)
+        num_eplb_groups = num_ep_groups
+        total_stateless_ports = (
+            num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
+        ) * 3
+        num_dp_master_ports = 5
+
+        all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
+
+        self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
+        self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
+        all_ports = all_ports[:-num_dp_master_ports]
+
+        self._stateless_world_group_port_list = [
+            all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
+        ]
+        start_idx = num_world_groups * 3
+        self._stateless_dp_group_port_list = [
+            all_ports[i : i + 3]
+            for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
+        ]
+        start_idx += num_dp_groups * 3
+        self._stateless_ep_group_port_list = [
+            all_ports[i : i + 3]
+            for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
+        ]
+        start_idx += num_ep_groups * 3
+        self._stateless_eplb_group_port_list = [
+            all_ports[i : i + 3]
+            for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
+        ]
+
+    def get_next_stateless_world_group_port(self) -> list[int]:
+        return self._stateless_world_group_port_list.pop()
+
+    def get_next_stateless_dp_group_port(self) -> list[int]:
+        return self._stateless_dp_group_port_list.pop()
+
+    def get_next_stateless_ep_group_port(self) -> list[int]:
+        return self._stateless_ep_group_port_list.pop()
+
+    def get_next_stateless_eplb_group_port(self) -> list[int]:
+        return self._stateless_eplb_group_port_list.pop()
+
+    def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
         # NOTE: In high-concurrency scenarios multiple processes
         # can pick the same (currently free) port through a race
         # condition when calling `get_open_port()`. When the first
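The slicing in `allocate_elastic_ep_ports` carves one flat allocation into consecutive triples, one triple per group. A self-contained sketch of that pattern (the helper below is a stand-in for `get_open_ports_list`, and the port numbers are invented):

# Standalone sketch of the chunking scheme used above; helper and ports
# are illustrative stand-ins, not code from this PR.
def fake_get_open_ports_list(n: int) -> list[int]:
    return list(range(20000, 20000 + n))

num_world_groups, num_dp_groups = 1, 2
total_ports = (num_world_groups + num_dp_groups) * 3
all_ports = fake_get_open_ports_list(total_ports)

# One triple for the world group, then one triple per DP group.
world_ports = [all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)]
start_idx = num_world_groups * 3
dp_ports = [
    all_ports[i : i + 3]
    for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]

assert world_ports == [[20000, 20001, 20002]]
assert dp_ports == [[20003, 20004, 20005], [20006, 20007, 20008]]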
@@ -426,7 +517,8 @@ class ParallelConfig:
                     self.get_next_dp_init_port(),
                     self.data_parallel_rank,
                     self.data_parallel_size,
-                    backend=current_platform.dist_backend,
+                    backend="gloo",
+                    return_store=return_store,
                 )
             except DistNetworkError as e:
                 # We only want to retry when the root cause is EADDRINUSE.
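The NOTE above refers to the retry this method performs when two processes race for the same free port. As a rough illustration of that pattern only (the helpers `next_port` and `bind_group` are placeholders, not vLLM's actual control flow):

# Rough illustration of retry-on-EADDRINUSE; assumes bind_group raises
# OSError(EADDRINUSE) when another process grabbed the port first.
import errno

def init_with_retry(next_port, bind_group, max_attempts: int = 5):
    for _ in range(max_attempts):
        port = next_port()
        try:
            return bind_group(port)
        except OSError as e:
            if e.errno != errno.EADDRINUSE:
                raise  # only the port race is worth retrying
    raise RuntimeError("exhausted retries while binding a DP init port")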
@@ -561,6 +653,21 @@ class ParallelConfig:
             logger.info("Using external launcher for distributed inference.")
             self.world_size *= self.data_parallel_size
 
+        if self.enable_elastic_ep:
+            if not self.enable_eplb:
+                raise ValueError("Elastic EP is only supported with enable_eplb=True.")
+            if self.pipeline_parallel_size > 1:
+                raise ValueError(
+                    "Elastic EP is not supported with pipeline parallelism "
+                    f"(pipeline_parallel_size={self.pipeline_parallel_size})."
+                )
+            if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
+                raise NotImplementedError(
+                    "Elastic EP is not compatible with data_parallel_external_lb "
+                    "or data_parallel_hybrid_lb. Elastic EP relies on a single API "
+                    "server and core client to coordinate scale up/down."
+                )
+
         if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             if self.distributed_executor_backend == "external_launcher":
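Taken together, these checks pin down the supported shape of an elastic-EP deployment: EPLB on, no pipeline parallelism, and neither external nor hybrid DP load balancing. A small stand-in illustrating a configuration that passes (the dataclass is hypothetical; only the constraints mirror the diff):

# Hypothetical stand-in config; only the checks mirror the validation above.
from dataclasses import dataclass

@dataclass
class ElasticEPConfig:
    enable_elastic_ep: bool = True
    enable_eplb: bool = True
    pipeline_parallel_size: int = 1
    data_parallel_external_lb: bool = False
    data_parallel_hybrid_lb: bool = False

cfg = ElasticEPConfig()
if cfg.enable_elastic_ep:
    assert cfg.enable_eplb, "Elastic EP requires enable_eplb=True"
    assert cfg.pipeline_parallel_size == 1, "Elastic EP forbids PP > 1"
    assert not (cfg.data_parallel_external_lb or cfg.data_parallel_hybrid_lb)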
@@ -573,9 +680,12 @@ class ParallelConfig:
                     "Set data_parallel_rank to %d automatically.",
                     self.data_parallel_rank,
                 )
-            if not self._data_parallel_master_port_list:
-                self._data_parallel_master_port_list = get_open_ports_list(5)
-            self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
+            if not self.enable_elastic_ep:
+                if not self._data_parallel_master_port_list:
+                    self._data_parallel_master_port_list = get_open_ports_list(5)
+                self.data_parallel_master_port = (
+                    self._data_parallel_master_port_list.pop()
+                )
 
             if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                 raise ValueError(
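This guard exists because, with elastic EP, `allocate_elastic_ep_ports` (added earlier in this diff) takes over DP master-port allocation and must run only after `ray.init()`; allocating here as well would claim ports twice and too early. A sketch of the intended split, with a hypothetical stub config for illustration:

# Stub mirroring the split above: with elastic EP, the DP master port is
# assigned by allocate_elastic_ep_ports() (after ray.init()), not here.
class StubParallelConfig:
    def __init__(self, enable_elastic_ep: bool) -> None:
        self.enable_elastic_ep = enable_elastic_ep
        self.data_parallel_master_port = 0

    def post_init_port_setup(self, open_port: int) -> None:
        if not self.enable_elastic_ep:  # the guard shown in the diff
            self.data_parallel_master_port = open_port

    def allocate_elastic_ep_ports(self, open_port: int) -> None:
        if self.enable_elastic_ep:  # runs after ray.init() in real use
            self.data_parallel_master_port = open_port

cfg = StubParallelConfig(enable_elastic_ep=True)
cfg.post_init_port_setup(16000)       # skipped: elastic EP defers allocation
cfg.allocate_elastic_ep_ports(16001)  # deferred allocation wins
assert cfg.data_parallel_master_port == 16001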
@@ -602,7 +712,7 @@ class ParallelConfig:
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
             logger.info("Disabling V1 multiprocessing for external launcher.")
 
-        if self.distributed_executor_backend is None and self.world_size > 1:
+        if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
             # We use multiprocessing by default if world_size fits on the
             # current node and we aren't in a ray placement group.
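The final hunk makes the executor-backend default consider the rank count across all DP replicas rather than a single engine's world size. For intuition about how the two quantities relate (the composition below follows the ParallelConfig field names; treat it as an assumption about how the sizes compose):

# Illustrative arithmetic only: with TP=2, PP=1, DP=4, one engine's
# world_size is 2, but 8 ranks exist across all DP replicas.
tensor_parallel_size, pipeline_parallel_size, data_parallel_size = 2, 1, 4
world_size = tensor_parallel_size * pipeline_parallel_size  # 2
world_size_across_dp = world_size * data_parallel_size      # 8
assert world_size == 2 and world_size_across_dp == 8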