Elastic Expert Parallel Initial Support (#20775)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
Rui Qiao
2025-07-18 17:46:09 -07:00
committed by GitHub
parent 5782581acf
commit 217937221b
24 changed files with 1659 additions and 68 deletions

View File

@@ -29,12 +29,15 @@ physical experts.
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch.distributed import all_gather, all_reduce
from torch.distributed import ProcessGroup, all_gather, all_reduce
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
in_the_same_node_as)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -172,6 +175,9 @@ class EplbState:
model: MixtureOfExperts,
device: torch.device,
parallel_config: ParallelConfig,
global_expert_load: Optional[torch.Tensor] = None,
old_global_expert_indices: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None,
) -> "EplbState":
"""
Build the initial EPLB state.
@@ -185,8 +191,16 @@ class EplbState:
physical_to_logical_map_list,
device=device,
)
# Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now.
# TODO(rui): make this configurable
MAX_EXPERT_REDUNDANCY = 1023
assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
f"num_redundant_experts {model.num_redundant_experts} "
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
logical_to_physical_map = torch.full(
(model.num_logical_experts, model.num_redundant_experts + 1),
(model.num_logical_experts, max_slots_per_logical_expert),
-1,
device=device,
)
@@ -235,11 +249,63 @@ class EplbState:
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (model.num_moe_layers,
model.num_logical_experts)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
expert_rearrangement_step = 0
return cls(
physical_to_logical_map,
@@ -337,7 +403,10 @@ class EplbState:
def rearrange(self,
model: MixtureOfExperts,
is_profile: bool = False) -> None:
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None) -> None:
"""
Rearrange the experts according to the current load.
"""
@@ -353,42 +422,79 @@ class EplbState:
logger.info("Rearranging experts %s...",
"(profile)" if is_profile else "")
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
if global_expert_load is None:
# This mapping is only used here, so we do not store it in the state
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
src=self.expert_load_window,
)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
metadata = torch.tensor(
[
model.num_moe_layers, model.num_logical_experts,
self.physical_to_logical_map.shape[1]
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(metadata,
group=get_ep_group().cpu_group,
group_src=0)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group,
group_src=0)
return global_expert_load_window
else:
assert execute_shuffle
global_expert_load_window = global_expert_load
# TODO(bowen): Treat differently for prefill and decode nodes
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
num_gpus = sum(new_rank != -1
for new_rank in rank_mapping.values())
num_replicas = num_replicas // ep_group.size(
) * num_gpus # handle num replicas change
else:
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
self.num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
@@ -414,10 +520,24 @@ class EplbState:
model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
if self.physical_to_logical_map.shape[
1] != new_physical_to_logical_map.shape[1]:
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0,
self.logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
self.logical_to_physical_map.copy_(new_logical_to_physical_map)
self.logical_replica_count.copy_(new_logical_replica_count)
@@ -430,3 +550,69 @@ class EplbState:
" (profile) " if is_profile else " ",
time_end - time_start,
)
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata,
group=ep_group.cpu_group,
group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist())
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group.device_group,
group_src=0)
return global_expert_load, old_global_expert_indices
def _node_count_with_rank_mapping(
pg: Union[ProcessGroup, StatelessProcessGroup],
rank_mapping: dict[int, int],
) -> int:
if isinstance(pg, ProcessGroup):
world_size = torch.distributed.get_world_size(group=pg)
else:
world_size = pg.world_size
if world_size == 1:
return 1
# Build node assignment map
node_assignment = [0] * world_size # rank -> node_id
next_node_id = 0
for current_rank in range(world_size):
if node_assignment[current_rank] != 0:
continue # Already assigned to a node
assert current_rank in rank_mapping
if rank_mapping[current_rank] == -1:
continue # Pending shutdown
# Assign current rank to a new node
next_node_id += 1
node_assignment[current_rank] = next_node_id
# Find all ranks on the same node as current_rank
same_node_flags = in_the_same_node_as(pg, current_rank)
for other_rank, is_same_node in enumerate(same_node_flags):
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id

View File

@@ -8,6 +8,7 @@ This involves the exchange of expert weights between GPUs.
from collections.abc import Iterable, MutableSequence, Sequence
from functools import partial
from typing import Optional
import torch
from torch.distributed import (P2POp, ProcessGroup, all_gather,
@@ -127,6 +128,8 @@ def shuffle_layer(
dst_global = local2global(dst)
if is_received_locally[dst]:
continue
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights,
@@ -139,6 +142,8 @@ def shuffle_layer(
experts_send_loc: dict[int, int] = {}
for src in range(num_local_experts):
expert = old_indices[local2global(src)]
if expert == -1:
continue
if expert in experts_send_loc:
continue
experts_send_loc[expert] = src
@@ -181,6 +186,8 @@ def shuffle_layer(
if is_received_locally[dst]:
continue
expert = new_indices[local2global(dst)]
if expert == -1:
continue
if expert in experts_recv_loc:
continue
experts_recv_loc[expert] = dst
@@ -227,6 +234,8 @@ def shuffle_layer(
weight[dst].copy_(buffer[dst])
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[src])
@@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace(
expert_weights: Sequence[Iterable[torch.Tensor]],
ep_group: ProcessGroup,
is_profile: bool = False,
rank_mapping: Optional[dict[int, int]] = None,
) -> None:
"""
Rearranges the expert weights in place according to the new expert indices.
@@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace(
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
rank_mapping: A dictionary mapping old rank to new rank.
"""
if rank_mapping is not None:
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = \
_map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = \
_map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[
1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
@@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace(
)
def _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
new_ep_size: int,
) -> torch.Tensor:
"""
Map the old global expert indices to the new global expert indices.
Args:
old_global_expert_indices:
Shape (num_layers, old_ep_size * num_local_physical_experts).
rank_mapping: Mapping from old rank to new rank.
new_ep_size: New expert parallelism size.
Returns:
Mapped expert indices with shape
(num_layers, new_ep_size * num_local_physical_experts).
"""
num_layers, old_num_physical_experts = old_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
num_local_physical_experts = old_num_physical_experts // old_ep_size
new_num_physical_experts = new_ep_size * num_local_physical_experts
# Create mapped tensor with new shape, initialized to -1
mapped_expert_indices = torch.full(
(num_layers, new_num_physical_experts),
fill_value=-1,
dtype=old_global_expert_indices.dtype,
device=old_global_expert_indices.device,
)
# Handle rank mapping (scale up/down with rank changes)
for old_rank in range(old_ep_size):
new_rank = rank_mapping.get(old_rank)
if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
# This old rank exists in the new configuration
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, new_start_idx:new_end_idx] = \
old_global_expert_indices[:, old_start_idx:old_end_idx]
# If new_rank is None or >= new_ep_size, the experts remain -1
# (scale down case)
return mapped_expert_indices
def _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices: torch.Tensor,
rank_mapping: dict[int, int],
) -> torch.Tensor:
num_layers, new_num_physical_experts = new_global_expert_indices.shape
assert rank_mapping, "Rank mapping is required"
# Get sizes from parameters and rank_mapping
old_ep_size = len(rank_mapping)
new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_local_physical_experts = new_num_physical_experts // new_ep_size
old_num_physical_experts = old_ep_size * num_local_physical_experts
mapped_expert_indices = torch.full(
(num_layers, old_num_physical_experts),
fill_value=-1,
dtype=new_global_expert_indices.dtype,
device=new_global_expert_indices.device,
)
for old_rank in range(old_ep_size):
new_rank = rank_mapping[old_rank]
if new_rank >= 0 and new_rank < new_ep_size:
old_start_idx = old_rank * num_local_physical_experts
old_end_idx = (old_rank + 1) * num_local_physical_experts
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts
mapped_expert_indices[:, old_start_idx:old_end_idx] = \
new_global_expert_indices[:, new_start_idx:new_end_idx]
return mapped_expert_indices
__all__ = ["rearrange_expert_weights_inplace"]