[BugFix] Support EP/DP + EPLB with MTP (#25311)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Author: Ilya Markov
Date: 2025-11-05 16:22:17 +01:00
Committed by: GitHub
Parent: 5d16d0fa62
Commit: e50c454672
27 changed files with 957 additions and 529 deletions


@@ -33,7 +33,7 @@ from dataclasses import dataclass
import torch
from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig
from vllm.config import ModelConfig, ParallelConfig
from vllm.distributed.parallel_state import (
get_ep_group,
get_node_count,
@@ -50,7 +50,7 @@ logger = init_logger(__name__)
@dataclass
class EplbState:
class EplbModelState:
"""EPLB metrics."""
physical_to_logical_map: torch.Tensor
@@ -130,34 +130,46 @@ class EplbState:
See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
"""
    expert_load_window_step: int = 0
    """
    Current step in the sliding window.

    Different from `expert_rearrangement_step`, each EP rank may have its own
    `expert_load_window_step`.
    """

    expert_load_window_size: int = 0
    """
    Size of the expert load sliding window.
    This is a constant and is taken from the config.
    """

    expert_rearrangement_step: int = 0
    """
    Steps after last rearrangement.
    Will trigger a rearrangement if it exceeds the threshold.

    NOTE: Keep in mind that all EP ranks need to have the same
    `expert_rearrangement_step` value to ensure synchronization.
    Otherwise, the rearrangement will hang at collective
    communication calls.
    """

    expert_rearrangement_step_interval: int = 0
    """
    Interval for expert rearrangement steps.
    This is a constant and is taken from the config.
    """

    model_name: str
    model: MixtureOfExperts


class EplbState:
    """
    EplbState of each expert parallel model. Key is the model config hash.
    """

    def __init__(self, parallel_config: ParallelConfig, device: torch.device):
        self.parallel_config = parallel_config
        self.device = device
        self.model_states: dict[str, EplbModelState] = {}

        """
        Current step in the sliding window.

        Different from `expert_rearrangement_step`,
        each EP rank may have its own `expert_load_window_step`.
        """
        self.expert_load_window_step: int = 0

        """
        Size of the expert load sliding window.
        This is a constant and is taken from the config.
        """
        self.expert_load_window_size: int = 0

        """
        Steps after last rearrangement.
        Will trigger a rearrangement if it exceeds the threshold.

        NOTE: Keep in mind that all EP ranks need to have the same
        `expert_rearrangement_step` value to ensure synchronization.
        Otherwise, the rearrangement will hang at collective
        communication calls.
        """
        self.expert_rearrangement_step: int = 0

        """
        Interval for expert rearrangement steps.
        This is a constant and is taken from the config.
        """
        self.expert_rearrangement_step_interval: int = 0
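To make the shape of this refactor concrete, here is a toy, self-contained sketch of the new bookkeeping: one `EplbState` now holds a dict of per-model `EplbModelState` entries keyed by the model config hash, so the main model and an MTP drafter can be stepped together. The stand-in class and hash strings below are illustrative only, not vLLM types.

from dataclasses import dataclass

@dataclass
class _ToyModelState:  # illustrative stand-in for EplbModelState
    model_name: str
    num_physical_experts: int

model_states: dict[str, _ToyModelState] = {}
model_states["hash-of-main-config"] = _ToyModelState("main-model", 66)
model_states["hash-of-drafter-config"] = _ToyModelState("mtp-drafter", 66)
assert len(model_states) == 2  # both models share one EPLB state object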
@staticmethod
def build_initial_global_physical_to_logical_map(
@@ -179,26 +191,63 @@ class EplbState:
]
return global_physical_to_logical_map
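For intuition, a small worked example of the initial mapping, assuming the usual scheme in which the redundant physical slots replicate logical experts round-robin (the authoritative layout is whatever the list construction above produces):

num_routed_experts, num_redundant_experts = 4, 2  # hypothetical sizes
initial_map = list(range(num_routed_experts)) + [
    i % num_routed_experts for i in range(num_redundant_experts)
]
# Physical slots 0-3 hold logical experts 0-3; slots 4-5 replicate experts 0 and 1.
assert initial_map == [0, 1, 2, 3, 0, 1]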
@classmethod
def build(
cls,
def validate_ep_configuration(self, new_model: MixtureOfExperts):
"""
        Validate that the expert parallel configuration of the new model
        matches that of the models already registered.
"""
if len(self.model_states) > 0:
model = next(iter(self.model_states.values())).model
if (
model.num_routed_experts != new_model.num_routed_experts
or model.num_redundant_experts != new_model.num_redundant_experts
or model.num_physical_experts != new_model.num_physical_experts
or model.num_logical_experts != new_model.num_logical_experts
or model.num_expert_groups != new_model.num_expert_groups
):
                raise RuntimeError(
                    "EP configuration mismatch: existing model {} has "
                    "(num_routed_experts={}, num_redundant_experts={}, "
                    "num_physical_experts={}, num_logical_experts={}, "
                    "num_expert_groups={}), but new model {} has "
                    "(num_routed_experts={}, num_redundant_experts={}, "
                    "num_physical_experts={}, num_logical_experts={}, "
                    "num_expert_groups={})".format(
                        type(model),
                        model.num_routed_experts,
                        model.num_redundant_experts,
                        model.num_physical_experts,
                        model.num_logical_experts,
                        model.num_expert_groups,
                        type(new_model),
                        new_model.num_routed_experts,
                        new_model.num_redundant_experts,
                        new_model.num_physical_experts,
                        new_model.num_logical_experts,
                        new_model.num_expert_groups,
                    )
                )
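A toy illustration of the invariant this check enforces, with made-up numbers: every MoE model sharing the EPLB state must report identical expert counts.

main_cfg = {
    "num_routed_experts": 64,
    "num_redundant_experts": 2,
    "num_physical_experts": 66,
    "num_logical_experts": 64,
    "num_expert_groups": 1,
}
drafter_cfg = dict(main_cfg)  # an MTP drafter with a matching EP configuration
assert all(main_cfg[key] == drafter_cfg[key] for key in main_cfg)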
def add_model(
self,
model: MixtureOfExperts,
device: torch.device,
parallel_config: ParallelConfig,
model_config: ModelConfig,
global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None,
) -> "EplbState":
):
"""
Build the initial EPLB state for the given model and register it.
"""
physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
self.validate_ep_configuration(model)
physical_to_logical_map_list = (
EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
)
)
physical_to_logical_map = torch.tensor(
physical_to_logical_map_list,
device=device,
device=self.device,
)
# Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now.
@@ -212,11 +261,11 @@ class EplbState:
logical_to_physical_map = torch.full(
(model.num_logical_experts, max_slots_per_logical_expert),
-1,
device=device,
device=self.device,
)
logical_replica_count = torch.zeros(
(model.num_logical_experts,),
device=device,
device=self.device,
dtype=torch.long,
)
@@ -255,18 +304,25 @@ class EplbState:
expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32,
device=device,
device=self.device,
)
expert_load_window_size = parallel_config.eplb_config.window_size
self.expert_load_window_size = self.parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, model.num_physical_experts),
(
self.expert_load_window_size,
model.num_moe_layers,
model.num_physical_experts,
),
dtype=torch.int32,
device=device,
device=self.device,
)
# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4)
eplb_step_interval = self.parallel_config.eplb_config.step_interval
self.expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4
)
self.expert_rearrangement_step_interval = eplb_step_interval
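A quick arithmetic check of the 3/4 warm-up above, with a hypothetical step interval (not taken from any real config):

eplb_step_interval = 3000  # hypothetical value
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4)
# The counter starts at 2250, so the first rearrangement triggers after 750 steps.
assert expert_rearrangement_step == 2250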
if global_expert_load is not None:
ep_group = get_ep_group().device_group
@@ -309,7 +365,7 @@ class EplbState:
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(device)
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
@@ -327,22 +383,20 @@ class EplbState:
False,
rank_mapping,
)
expert_rearrangement_step = 0
self.expert_rearrangement_step = 0
return cls(
self.model_states[model_config.compute_hash()] = EplbModelState(
physical_to_logical_map,
logical_to_physical_map,
logical_replica_count,
expert_load_pass,
expert_load_window,
expert_load_window_size=expert_load_window_size,
expert_rearrangement_step=expert_rearrangement_step,
expert_rearrangement_step_interval=eplb_step_interval,
model_config.model,
model,
)
def step(
self,
model: MixtureOfExperts,
is_dummy: bool = False,
is_profile: bool = False,
log_stats: bool = False,
@@ -351,7 +405,6 @@ class EplbState:
Step the EPLB state.
Args:
model (MixtureOfExperts): The MoE model.
is_dummy (bool): If `True`, this is a dummy step and the load
metrics recorded in this forward pass will not count.
Defaults to `False`.
@@ -369,60 +422,66 @@ class EplbState:
"""
if is_profile:
self.rearrange(model, is_profile=True)
self.rearrange(is_profile=True)
return
if is_dummy:
# Do not record load metrics for dummy steps
self.expert_load_pass.zero_()
for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_pass.zero_()
if log_stats:
# total_expert_load_pass: (num_moe_layers, num_physical_experts)
total_expert_load_pass = self.expert_load_pass.clone()
# Collect load metrics from all ranks
# Sync the expert load pass for each model (main and drafter).
# expert_load_pass: (num_moe_layers, num_physical_experts)
expert_load_pass_list = self._sync_load_pass()
ep_group = get_ep_group().device_group
all_reduce(total_expert_load_pass, group=ep_group)
# num_tokens_per_rank: (num_moe_layers, num_ranks)
num_tokens_per_rank = (
total_expert_load_pass.reshape(
total_expert_load_pass.shape[0], ep_group.size(), -1
for expert_load_pass, eplb_model_state in zip(
expert_load_pass_list, self.model_states.values()
):
# num_tokens_per_rank: (num_moe_layers, num_ranks)
num_tokens_per_rank = (
expert_load_pass.reshape(
expert_load_pass.shape[0], ep_group.size(), -1
)
.sum(dim=-1)
.float()
)
.sum(dim=-1)
.float()
)
# Compute balancedness ratio:
# for each layer:
# (mean load across ranks) / (max load across ranks)
avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0)
# Compute balancedness ratio:
# for each layer:
# (mean load across ranks) / (max load across ranks)
avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0)
# Just to make type checker happy
tokens_tensors: list[float] = torch.stack(
[avg_tokens_tensor, max_tokens_tensor]
).tolist()
avg_tokens, max_tokens = tokens_tensors
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
# Just to make type checker happy
tokens_tensors: list[float] = torch.stack(
[avg_tokens_tensor, max_tokens_tensor]
).tolist()
avg_tokens, max_tokens = tokens_tensors
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
if ep_group.rank() == 0:
logger.info(
"EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f",
avg_tokens,
max_tokens,
balancedness,
)
if ep_group.rank() == 0:
logger.info(
"EPLB step: %d for model %s: avg_tokens=%.2f, "
"max_tokens=%d, balancedness=%.4f",
self.expert_rearrangement_step,
eplb_model_state.model_name,
avg_tokens,
max_tokens,
balancedness,
)
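A minimal, single-process sketch of the balancedness arithmetic above, with made-up per-rank token counts (2 MoE layers, 4 EP ranks):

import torch

num_tokens_per_rank = torch.tensor(
    [[100.0, 80.0, 120.0, 100.0],
     [90.0, 110.0, 100.0, 100.0]]
)  # hypothetical (num_moe_layers, num_ranks) load
avg_tokens = num_tokens_per_rank.mean(dim=0).sum(dim=0).item()        # 400.0
max_tokens = num_tokens_per_rank.max(dim=0).values.sum(dim=0).item()  # 430.0
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
# 1.0 means perfectly balanced ranks; here balancedness is roughly 0.93.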
# Update the expert load sliding window
if not is_dummy:
self.expert_load_window[self.expert_load_window_step] = (
self.expert_load_pass.clone()
)
for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_window[self.expert_load_window_step] = (
eplb_model_state.expert_load_pass.clone()
)
eplb_model_state.expert_load_pass.zero_()
self.expert_load_window_step += 1
if self.expert_load_window_step >= self.expert_load_window_size:
self.expert_load_window_step = 0
self.expert_load_pass.zero_()
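A hedged sketch of the sliding-window bookkeeping, using made-up sizes: the window keeps the last `window_size` per-pass loads, and `expert_load_window_step` wraps around so the oldest entry is overwritten.

import torch

window_size, num_layers, num_experts = 4, 2, 8  # hypothetical sizes
expert_load_window = torch.zeros(
    window_size, num_layers, num_experts, dtype=torch.int32
)
step = 0
for _ in range(10):  # ten simulated engine steps
    load_this_pass = torch.randint(0, 5, (num_layers, num_experts), dtype=torch.int32)
    expert_load_window[step] = load_this_pass
    step += 1
    if step >= window_size:
        step = 0  # wrap: the next write overwrites the oldest entry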
# Step the expert rearrangement step
# Note that even if this is a dummy step, we still increment the
@@ -431,18 +490,30 @@ class EplbState:
self.expert_rearrangement_step += 1
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
self.expert_rearrangement_step = 0
self.rearrange(model)
self.rearrange()
def rearrange(
self,
model: MixtureOfExperts,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: torch.Tensor | None = None,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
Rearrange the experts according to the current load.
Args:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
            global_expert_loads (list[torch.Tensor] | None): The global expert
                loads when scaling is done in EEP. One entry per model: the
                main model and, when speculative decoding is used, the MTP
                drafter.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
"""
ep_group = get_ep_group().device_group
@@ -455,53 +526,71 @@ class EplbState:
time_start = time.perf_counter()
logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
if global_expert_load is None:
if global_expert_loads is None:
# Map the physical expert load to global logical experts
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=self.physical_to_logical_map.unsqueeze(0)
.expand_as(self.expert_load_window)
.long(),
src=self.expert_load_window,
)
global_expert_load_windows = []
if not execute_shuffle:
metadata = torch.tensor(
[
model.num_moe_layers,
model.num_logical_experts,
self.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
num_models = torch.tensor(
[len(self.model_states)], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
num_models, group=get_ep_group().cpu_group, group_src=0
)
# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
all_reduce(global_expert_load_window, group=ep_group)
for eplb_model_state in self.model_states.values():
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
.expand_as(eplb_model_state.expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
)
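A toy, single-process sketch of the physical-to-logical aggregation performed above (made-up map and loads; in the real code the summed result is then all-reduced across the EP group):

import torch

# One layer, four physical slots, three logical experts; slot 3 replicates expert 0.
physical_to_logical_map = torch.tensor([[0, 1, 2, 0]])   # (layers, physical_experts)
expert_load_window = torch.tensor([[[5, 2, 3, 4]]])      # (window, layers, physical_experts)
logical_load = torch.zeros(1, 1, 3, dtype=expert_load_window.dtype)
logical_load.scatter_add_(
    dim=-1,
    index=physical_to_logical_map.unsqueeze(0).expand_as(expert_load_window).long(),
    src=expert_load_window,
)
# The two replicas of logical expert 0 are summed: 5 + 4 = 9.
assert logical_load.tolist() == [[[9, 2, 3]]]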
if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
return global_expert_load_window
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else:
assert execute_shuffle
global_expert_load_window = global_expert_load
global_expert_load_windows = global_expert_loads
# TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
model = eplb_model_state.model
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
if rank_mapping is not None and len(rank_mapping) == ep_group.size():
@@ -526,48 +615,64 @@ class EplbState:
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
# Update expert weights
rearrange_expert_weights_inplace(
self.physical_to_logical_map,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
if (
self.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1]
):
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device
)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
            max_physical_slots = new_logical_to_physical_map.shape[-1]
            assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
            new_logical_to_physical_map = torch.nn.functional.pad(
                new_logical_to_physical_map,
                (0, self.logical_to_physical_map.shape[-1] - max_physical_slots),
                value=-1,
            )
            self.logical_to_physical_map.copy_(new_logical_to_physical_map)
            self.logical_replica_count.copy_(new_logical_replica_count)

        for eplb_model_state, global_expert_load_window in zip(
            self.model_states.values(), global_expert_load_windows
        ):
            # Get new expert mappings for the model
            (
                new_physical_to_logical_map,
                new_logical_to_physical_map,
                new_logical_replica_count,
            ) = rebalance_experts(
                global_expert_load_window,
                num_replicas,
                num_groups,
                num_nodes,
                num_gpus,
            )
# Update expert weights
rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map,
new_physical_to_logical_map,
eplb_model_state.model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
if (
eplb_model_state.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1]
):
eplb_model_state.physical_to_logical_map = (
new_physical_to_logical_map.to(
eplb_model_state.physical_to_logical_map.device
)
)
else:
eplb_model_state.physical_to_logical_map.copy_(
new_physical_to_logical_map
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert (
max_physical_slots
<= eplb_model_state.logical_to_physical_map.shape[-1]
)
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(
0,
eplb_model_state.logical_to_physical_map.shape[-1]
- max_physical_slots,
),
value=-1,
)
eplb_model_state.logical_to_physical_map.copy_(
new_logical_to_physical_map
)
eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
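A toy example of the padding step in this loop: the rebalanced mapping may use fewer replica slots than the preallocated buffer, so it is right-padded with -1 before being copied in (made-up shapes):

import torch

logical_to_physical_map = torch.full((3, 4), -1, dtype=torch.long)       # buffer: 4 slots
new_logical_to_physical_map = torch.tensor([[0, 3], [1, -1], [2, -1]])   # uses only 2 slots
max_physical_slots = new_logical_to_physical_map.shape[-1]
padded = torch.nn.functional.pad(
    new_logical_to_physical_map,
    (0, logical_to_physical_map.shape[-1] - max_physical_slots),
    value=-1,
)
logical_to_physical_map.copy_(padded)
assert logical_to_physical_map.tolist() == [
    [0, 3, -1, -1],
    [1, -1, -1, -1],
    [2, -1, -1, -1],
]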
if is_main_rank:
assert time_start is not None
@@ -581,32 +686,118 @@ class EplbState:
return None
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist()
)
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
num_models = num_models.item()
global_expert_loads = []
old_global_expert_indices_per_model = []
for _ in range(num_models):
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist()
)
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(
old_global_expert_indices,
group=ep_group.device_group,
group_src=0,
)
global_expert_loads.append(global_expert_load)
old_global_expert_indices_per_model.append(old_global_expert_indices)
return global_expert_loads, old_global_expert_indices_per_model
        torch.distributed.broadcast(
            old_global_expert_indices, group=ep_group.device_group, group_src=0
        )
        return global_expert_load, old_global_expert_indices

    @classmethod
    def get_eep_state(
        cls, parallel_config: ParallelConfig
    ) -> tuple[
        list[torch.Tensor] | None,
        list[torch.Tensor] | None,
        dict[int, int] | None,
    ]:
        num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
        torch.distributed.broadcast(
            num_local_physical_experts,
            group=get_ep_group().cpu_group,
            group_src=0,
        )
        num_local_physical_experts = int(num_local_physical_experts.item())
        new_ep_size = get_ep_group().world_size
        global_expert_loads, old_global_expert_indices_per_model = (
            EplbState.recv_state()
        )
        # The EP configuration must be the same for all models, and so must the EPLB config.
num_logical_experts = global_expert_loads[0].shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert (
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
== 0
)
old_ep_size = (
old_global_expert_indices_per_model[0].shape[1]
// num_local_physical_experts
)
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
return (
global_expert_loads,
old_global_expert_indices_per_model,
rank_mapping,
)
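A worked example of the arithmetic above with hypothetical sizes (8 local physical experts per rank, a new EP size of 4, and 30 logical experts):

num_local_physical_experts = 8   # hypothetical
new_ep_size = 4                  # hypothetical
num_logical_experts = 30         # hypothetical
num_redundant_experts = num_local_physical_experts * new_ep_size - num_logical_experts
assert num_redundant_experts == 2

old_num_physical_experts = 16    # columns of the old placement, hypothetical
assert old_num_physical_experts % num_local_physical_experts == 0
old_ep_size = old_num_physical_experts // num_local_physical_experts
assert old_ep_size == 2          # scaling up from 2 ranks to 4 ranks
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}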
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
"""
if len(tensor_list) == 1:
all_reduce(tensor_list[0], group=get_ep_group().device_group)
return tensor_list
assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D."
assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), (
"All tensors must have the same shape[1]."
)
# Concatenate, all_reduce, then unpack to original shapes.
# We assume all tensors are 2D and shape[1] (num_physical_experts)
# is the same across all models.
shapes = [t.shape for t in tensor_list]
concat_tensor = torch.cat(tensor_list, dim=0)
ep_group = get_ep_group().device_group
all_reduce(concat_tensor, group=ep_group)
all_reduce_list = []
offset = 0
for shape in shapes:
all_reduce_list.append(concat_tensor[offset : offset + shape[0], :])
offset += shape[0]
return all_reduce_list
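A local illustration of the pack/unpack pattern used here, with the collective replaced by a comment since a real all_reduce needs an initialized EP process group (made-up tensor sizes):

import torch

tensor_list = [
    torch.ones(2, 8, dtype=torch.int64),      # e.g. main model: 2 MoE layers
    torch.ones(3, 8, dtype=torch.int64) * 2,  # e.g. drafter: 3 MoE layers
]
shapes = [t.shape for t in tensor_list]
concat_tensor = torch.cat(tensor_list, dim=0)  # (5, 8): one collective instead of two
# all_reduce(concat_tensor, group=ep_group)    # would run here on the real EP group
unpacked, offset = [], 0
for shape in shapes:
    unpacked.append(concat_tensor[offset : offset + shape[0], :])
    offset += shape[0]
assert [t.shape for t in unpacked] == [torch.Size([2, 8]), torch.Size([3, 8])]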
def _sync_load_pass(self) -> list[torch.Tensor]:
"""
        Sync the per-pass expert load across all ranks for logging stats.
        Does not modify the `expert_load_pass` stored in each `eplb_model_state`.
"""
load_pass_list = []
for eplb_model_state in self.model_states.values():
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
def _node_count_with_rank_mapping(