[EPLB] Simplify EPLB rearrange by only returning one map (#36267)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Sage Moore
2026-03-18 17:34:00 -07:00
committed by GitHub
parent ef2c4f778d
commit c32a58cc2a
5 changed files with 197 additions and 160 deletions

View File

@@ -5,6 +5,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from vllm.distributed.eplb.eplb_state import compute_logical_maps
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
@@ -24,9 +25,10 @@ def test_basic_rebalance():
num_nodes = 2 num_nodes = 2
num_gpus = 8 num_gpus = 8
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify output shapes # Verify output shapes
assert phy2log.shape == ( assert phy2log.shape == (
@@ -78,9 +80,10 @@ def test_single_gpu_case():
num_nodes = 1 num_nodes = 1
num_gpus = 1 num_gpus = 1
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes # Verify shapes
assert phy2log.shape == (1, 4) assert phy2log.shape == (1, 4)
@@ -100,9 +103,10 @@ def test_equal_weights():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes # Verify shapes
assert phy2log.shape == (1, 8) assert phy2log.shape == (1, 8)
@@ -123,9 +127,10 @@ def test_extreme_weight_imbalance():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes # Verify shapes
assert phy2log.shape == (1, 12) assert phy2log.shape == (1, 12)
@@ -151,9 +156,10 @@ def test_multiple_layers():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes # Verify shapes
assert phy2log.shape == (3, 8) assert phy2log.shape == (3, 8)
@@ -176,7 +182,8 @@ def test_parameter_validation():
# Test non-divisible case - this should handle normally without throwing # Test non-divisible case - this should handle normally without throwing
# errors because the function will fall back to global load balancing # errors because the function will fall back to global load balancing
# strategy # strategy
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4) phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
assert phy2log.shape == (1, 8) assert phy2log.shape == (1, 8)
assert logcnt.shape == (1, 4) assert logcnt.shape == (1, 4)
@@ -198,9 +205,10 @@ def test_small_scale_hierarchical():
num_nodes = 2 # 2 nodes num_nodes = 2 # 2 nodes
num_gpus = 4 # 4 GPUs num_gpus = 4 # 4 GPUs
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify basic constraints # Verify basic constraints
assert phy2log.shape == (1, 12) assert phy2log.shape == (1, 12)
@@ -225,9 +233,10 @@ def test_global_load_balance_fallback():
num_nodes = 2 num_nodes = 2
num_gpus = 4 num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Should work normally, just using global load balancing strategy # Should work normally, just using global load balancing strategy
assert phy2log.shape == (1, 8) assert phy2log.shape == (1, 8)
@@ -247,9 +256,10 @@ def test_device_compatibility(device):
num_nodes = 1 num_nodes = 1
num_gpus = 2 num_gpus = 2
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Function will convert to CPU internally, but should handle different # Function will convert to CPU internally, but should handle different
# device inputs normally # device inputs normally
@@ -264,9 +274,8 @@ def test_additional_cases():
weight1 = torch.tensor( weight1 = torch.tensor(
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
) )
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts( phy2log1 = DefaultEplbPolicy.rebalance_experts(weight1, 24, 8, 4, 8)
weight1, 24, 8, 4, 8 _, logcnt1 = compute_logical_maps(phy2log1, weight1.shape[-1])
)
assert phy2log1.shape == (1, 24) assert phy2log1.shape == (1, 24)
assert logcnt1.shape == (1, 16) assert logcnt1.shape == (1, 16)
@@ -279,9 +288,8 @@ def test_additional_cases():
[12, 25, 50, 100, 150, 200], # Increasing weights [12, 25, 50, 100, 150, 200], # Increasing weights
] ]
) )
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts( phy2log2 = DefaultEplbPolicy.rebalance_experts(weight2, 10, 3, 1, 2)
weight2, 10, 3, 1, 2 _, logcnt2 = compute_logical_maps(phy2log2, weight2.shape[-1])
)
assert phy2log2.shape == (2, 10) assert phy2log2.shape == (2, 10)
assert logcnt2.shape == (2, 6) assert logcnt2.shape == (2, 6)
@@ -292,6 +300,42 @@ def test_additional_cases():
assert logcnt2[layer, max_weight_idx] >= 2 assert logcnt2[layer, max_weight_idx] >= 2
def test_compute_logical_maps_with_negative_indices():
    """
    compute_logical_maps must ignore physical slots holding -1 (unused
    slots) when building the reverse map and the replica counts.
    """
    # 2 layers, 6 physical slots, 4 logical experts; slots 2 and 5 unused.
    phy2log = torch.tensor(
        [
            [0, 1, -1, 2, 3, -1],
            [3, -1, 2, 1, 0, -1],
        ]
    )
    num_layers, num_logical_experts = 2, 4

    log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts)

    # Shapes: one replica per logical expert, so max_replicas == 1.
    assert logcnt.shape == (num_layers, num_logical_experts)
    assert log2phy.shape == (num_layers, num_logical_experts, 1)

    # Each logical expert appears exactly once among the valid slots.
    expected_logcnt = torch.ones(num_layers, num_logical_experts, dtype=phy2log.dtype)
    assert torch.all(logcnt == expected_logcnt), (
        f"Expected that all replica counts == 1, got {logcnt}"
    )

    # -1 slots must never leak into the reverse map.
    assert torch.all(log2phy >= 0), (
        "log2phy should only contain valid physical indices, not -1"
    )

    # Layer 0: logical experts 0..3 live in physical slots 0, 1, 3, 4.
    for logical_idx, expected_phys in enumerate((0, 1, 3, 4)):
        assert log2phy[0, logical_idx, 0] == expected_phys
if __name__ == "__main__": if __name__ == "__main__":
weight = torch.tensor( weight = torch.tensor(
[ [
@@ -305,7 +349,7 @@ if __name__ == "__main__":
num_nodes = 2 num_nodes = 2
num_gpus = 8 num_gpus = 8
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus weight, num_replicas, num_groups, num_nodes, num_gpus
) )
print(phy2log) print(phy2log)
@@ -434,9 +478,10 @@ def test_preserve_intragpu_slots(
"""Experts that stay on a GPU keep their old slots; incoming not lost.""" """Experts that stay on a GPU keep their old slots; incoming not lost."""
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log) phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots( post_phy2log = DefaultEplbPolicy.preserve_intragpu_slots(
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log new_phy2log, num_ranks, old_phy2log
) )
post_phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(post_phy2log)
# Shapes preserved # Shapes preserved
assert post_phy2log.shape == new_phy2log.shape assert post_phy2log.shape == new_phy2log.shape

View File

@@ -73,11 +73,7 @@ def run_rebalance_experts(
# Move the global expert load window to CPU for computation. # Move the global expert load window to CPU for computation.
global_expert_load_window = eplb_stats.global_expert_load_window.cpu() global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
# Compute new expert mappings for the model # Compute new expert mappings for the model
( new_physical_to_logical_map = eplb_state.policy.rebalance_experts(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = eplb_state.policy.rebalance_experts(
global_expert_load_window, global_expert_load_window,
eplb_stats.num_replicas, eplb_stats.num_replicas,
eplb_stats.num_groups, eplb_stats.num_groups,
@@ -89,16 +85,6 @@ def run_rebalance_experts(
model_state.new_physical_to_logical_map = new_physical_to_logical_map model_state.new_physical_to_logical_map = new_physical_to_logical_map
max_slots = model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
).to(model_state.logical_to_physical_map.device)
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
model_state.new_logical_to_physical_map = padded_logical
model_state.new_logical_replica_count = new_replica
async def transfer_run_periodically( async def transfer_run_periodically(
state: "EplbState", state: "EplbState",

View File

@@ -235,16 +235,6 @@ class EplbModelState:
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
the size is same as physical_to_logical_map the size is same as physical_to_logical_map
""" """
new_logical_to_physical_map: torch.Tensor | None = None
"""
intermediate variable between `move_to_buffer` and `move_to_workspace`.
the size is same as logical_to_physical_map
"""
new_logical_replica_count: torch.Tensor | None = None
"""
intermediate variable between `move_to_buffer` and `move_to_workspace`.
the size is same as logical_replica_count
"""
class EplbState: class EplbState:
@@ -508,8 +498,6 @@ class EplbState:
), ),
cuda_device_index=self.cuda_device_index, cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=None, new_physical_to_logical_map=None,
new_logical_to_physical_map=None,
new_logical_replica_count=None,
) )
self.model_states[model_config.compute_hash()] = model_state self.model_states[model_config.compute_hash()] = model_state
self.num_valid_physical_experts = model.num_physical_experts self.num_valid_physical_experts = model.num_physical_experts
@@ -738,17 +726,20 @@ class EplbState:
): ):
if not self.is_async or is_profile: if not self.is_async or is_profile:
# Get new expert mappings for the model # Get new expert mappings for the model
( new_physical_to_logical_map = self.policy.rebalance_experts(
new_physical_to_logical_map, global_expert_load_window.cpu(),
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load_window,
num_replicas, num_replicas,
num_groups, num_groups,
num_nodes, num_nodes,
num_gpus, num_gpus,
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map.cpu(),
)
num_logical_experts = global_expert_load_window.shape[-1]
(new_logical_to_physical_map, new_logical_replica_count) = (
compute_logical_maps(
new_physical_to_logical_map, num_logical_experts
)
) )
# Update expert weights # Update expert weights
@@ -847,11 +838,7 @@ class EplbState:
def _update_layer_mapping_from_new( def _update_layer_mapping_from_new(
self, model_state: EplbModelState, layer: int self, model_state: EplbModelState, layer: int
) -> None: ) -> None:
if ( if model_state.new_physical_to_logical_map is None:
model_state.new_physical_to_logical_map is None
or model_state.new_logical_to_physical_map is None
or model_state.new_logical_replica_count is None
):
return return
target_device = model_state.physical_to_logical_map.device target_device = model_state.physical_to_logical_map.device
@@ -865,19 +852,23 @@ class EplbState:
new_physical[layer].to(target_device, non_blocking=True) new_physical[layer].to(target_device, non_blocking=True)
) )
num_logical_experts = model_state.logical_to_physical_map.shape[1]
new_logical, new_replica_count = compute_logical_maps(
new_physical[layer], num_logical_experts
)
logical_device = model_state.logical_to_physical_map.device logical_device = model_state.logical_to_physical_map.device
new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
max_slots = model_state.logical_to_physical_map.shape[-1] max_slots = model_state.logical_to_physical_map.shape[-1]
slot_delta = max_slots - new_logical.shape[-1] slot_delta = max_slots - new_logical.shape[-1]
if slot_delta > 0: if slot_delta > 0:
new_logical = torch.nn.functional.pad( new_logical = torch.nn.functional.pad(
new_logical, (0, slot_delta), value=-1 new_logical, (0, slot_delta), value=-1
) )
model_state.logical_to_physical_map[layer].copy_(new_logical) model_state.logical_to_physical_map[layer].copy_(new_logical.to(logical_device))
replica_device = model_state.logical_replica_count.device replica_device = model_state.logical_replica_count.device
model_state.logical_replica_count[layer].copy_( model_state.logical_replica_count[layer].copy_(
model_state.new_logical_replica_count[layer].to(replica_device) new_replica_count.to(replica_device)
) )
def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool: def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
@@ -966,7 +957,7 @@ class EplbState:
transferred_layer, transferred_layer,
) )
if model_state.layer_to_transfer >= model_state.model.num_moe_layers: if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
self.post_eplb(model_state, is_profile) self.post_eplb(model_state)
model_state.rebalanced = False model_state.rebalanced = False
model_state.layer_to_transfer = 0 model_state.layer_to_transfer = 0
model_state.pending_global_ready_check = False model_state.pending_global_ready_check = False
@@ -987,14 +978,9 @@ class EplbState:
str(e), str(e),
) )
def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> None: def post_eplb(self, model_state: EplbModelState) -> None:
assert model_state.new_physical_to_logical_map is not None assert model_state.new_physical_to_logical_map is not None
assert model_state.new_logical_to_physical_map is not None
assert model_state.new_logical_replica_count is not None
model_state.new_physical_to_logical_map = None model_state.new_physical_to_logical_map = None
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
""" """
@@ -1052,39 +1038,28 @@ class EplbState:
model_config=model_config, model_config=model_config,
) )
eplb_state.num_valid_physical_experts = num_valid_physical_experts eplb_state.num_valid_physical_experts = num_valid_physical_experts
num_moe_layers = expanded_physical_to_logical.shape[0]
num_physical_experts = expanded_physical_to_logical.shape[1]
eplb_model_state = eplb_state.model_states[model_config.compute_hash()] eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical) eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
logical_to_physical_map = torch.full( (logical_to_physical_map_cpu, logical_replica_count_cpu) = compute_logical_maps(
( expanded_physical_to_logical.cpu(), model.num_logical_experts
num_moe_layers,
model.num_logical_experts,
eplb_model_state.logical_to_physical_map.shape[2],
),
-1,
dtype=torch.int64,
) )
logical_replica_count = torch.zeros(
(num_moe_layers, model.num_logical_experts),
dtype=torch.int64,
)
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
for layer_idx in range(num_moe_layers):
for phys_idx in range(num_physical_experts):
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
if logical_idx >= 0:
replica_idx = logical_replica_count[layer_idx, logical_idx]
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
phys_idx
)
logical_replica_count[layer_idx, logical_idx] += 1
logical_to_physical_map = logical_to_physical_map.to(device) max_num_replicas = eplb_model_state.logical_to_physical_map.shape[-1]
logical_replica_count = logical_replica_count.to(device) num_replicas = logical_to_physical_map_cpu.shape[-1]
logical_to_physical_map = torch.nn.functional.pad(
logical_to_physical_map_cpu,
(
0,
max_num_replicas - num_replicas,
),
value=-1,
).to(device)
logical_replica_count = logical_replica_count_cpu.to(device)
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map) eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
eplb_model_state.logical_replica_count.copy_(logical_replica_count) eplb_model_state.logical_replica_count.copy_(logical_replica_count)
return eplb_state return eplb_state
@@ -1132,3 +1107,82 @@ def _node_count_with_rank_mapping(
node_assignment[other_rank] = next_node_id node_assignment[other_rank] = next_node_id
return next_node_id return next_node_id
def compute_logical_maps(
    physical_to_logical_map: torch.Tensor,
    num_logical_experts: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Derive `logical_to_physical_map` and `logical_replica_count` from
    `physical_to_logical_map`.

    Physical slots holding ``-1`` (unused slots, e.g. newly added slots
    after a scale-up) are ignored for both outputs.

    Args:
        physical_to_logical_map: [num_layers, num_physical_experts] (or
            [num_physical_experts] for a single layer), the logical expert
            index of each physical expert slot; ``-1`` marks an unused
            slot. Must be a CPU tensor.
        num_logical_experts: total number of logical experts.

    Returns:
        logical_to_physical_map: [num_layers, num_logical_experts,
            max_replicas], physical slots per logical expert; ``-1`` where
            unused. The layer dimension is squeezed away for 1-D input.
        logical_replica_count: [num_layers, num_logical_experts], number of
            physical replicas per logical expert (layer dimension squeezed
            for 1-D input).

    Raises:
        ValueError: if the input tensor is not on CPU or is not 1-D/2-D.
    """
    # Explicit raises instead of `assert`: asserts vanish under `python -O`.
    if physical_to_logical_map.device.type != "cpu":
        raise ValueError(
            "physical_to_logical_map must be a CPU tensor, got device "
            f"{physical_to_logical_map.device}"
        )
    device = physical_to_logical_map.device
    dtype = physical_to_logical_map.dtype

    # A 1-D input describes a single layer; add a layer dim for uniformity.
    per_layer = physical_to_logical_map.dim() == 1
    physical_to_logical_map_view = (
        physical_to_logical_map.unsqueeze(0) if per_layer else physical_to_logical_map
    )
    if physical_to_logical_map_view.dim() != 2:
        raise ValueError(
            "physical_to_logical_map must be 1-D or 2-D, got "
            f"{physical_to_logical_map.dim()}-D"
        )
    num_layers, num_physical = physical_to_logical_map_view.shape

    # Count replicas per logical expert. clamp(min=0) redirects -1 entries to
    # index 0, but the mask contributes 0 there, so unused slots don't count.
    valid_mask = physical_to_logical_map_view >= 0
    logical_replica_count = torch.zeros(
        num_layers,
        num_logical_experts,
        dtype=dtype,
        device=device,
    )
    logical_replica_count.scatter_add_(
        1,
        physical_to_logical_map_view.clamp(min=0),
        valid_mask.to(dtype),
    )

    # Pad the reverse map to the widest replica count. Guard against an empty
    # count tensor (zero layers or zero logical experts): max() would raise.
    max_replicas = (
        int(logical_replica_count.max().item())
        if logical_replica_count.numel() > 0
        else 0
    )
    logical_to_physical_map_out = torch.full(
        (num_layers, num_logical_experts, max_replicas),
        -1,
        dtype=dtype,
        device=device,
    )

    running_count = torch.zeros_like(logical_replica_count)
    layer_indices = torch.arange(num_layers, device=device)
    for phys_idx in range(num_physical):
        # Logical expert at physical slot phys_idx for each layer.
        logical_expert_ids = physical_to_logical_map_view[:, phys_idx]  # [num_layers]
        # Scale up will set the logical expert ids to -1 for all new physical
        # experts. Only consider "valid" experts when filling the reverse map.
        valid_expert_mask = logical_expert_ids >= 0
        if not valid_expert_mask.any():
            continue
        valid_layers = layer_indices[valid_expert_mask]
        valid_experts = logical_expert_ids[valid_expert_mask]
        # Use the current running count as the replica index, then increment.
        replica_idx = running_count[valid_layers, valid_experts]
        logical_to_physical_map_out[valid_layers, valid_experts, replica_idx] = phys_idx
        running_count[valid_layers, valid_experts] += 1

    if per_layer:
        # Squeeze out the synthetic layer dimension for 1-D input.
        return logical_to_physical_map_out.squeeze(0), logical_replica_count.squeeze(0)
    return logical_to_physical_map_out, logical_replica_count

View File

@@ -17,7 +17,7 @@ class AbstractEplbPolicy(ABC):
num_nodes: int, num_nodes: int,
num_ranks: int, num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None, old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
""" """
Entry point for expert-parallelism load balancer. Entry point for expert-parallelism load balancer.
@@ -35,9 +35,5 @@ class AbstractEplbPolicy(ABC):
Returns: Returns:
physical_to_logical_map: [layers, num_replicas], the expert physical_to_logical_map: [layers, num_replicas], the expert
index of each replica index of each replica
logical_to_physical_map: [layers, num_logical_experts, X],
the replica indices for each expert
expert_count: [layers, num_logical_experts], number of
physical replicas for each logical expert
""" """
raise NotImplementedError raise NotImplementedError

View File

@@ -75,7 +75,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod @classmethod
def replicate_experts( def replicate_experts(
cls, weight: np.ndarray, num_phy: int cls, weight: np.ndarray, num_phy: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Replicate `num_log` experts to `num_phy` replicas, such that the maximum Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized. load of all replicas is minimized.
@@ -86,22 +86,19 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [X, num_phy], logical expert id of each physical expert phy2log: [X, num_phy], logical expert id of each physical expert
replica_idx: [X, num_phy], the index of the replica for each logical expert
logcnt: [X, num_log], number of replicas for each logical expert logcnt: [X, num_log], number of replicas for each logical expert
""" """
n, num_log = weight.shape n, num_log = weight.shape
num_redundant = num_phy - num_log num_redundant = num_phy - num_log
assert num_redundant >= 0 assert num_redundant >= 0
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1)) phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
logcnt = np.ones((n, num_log), dtype=np.int64) logcnt = np.ones((n, num_log), dtype=np.int64)
arangen = np.arange(n, dtype=np.int64) arangen = np.arange(n, dtype=np.int64)
for i in range(num_log, num_phy): for i in range(num_log, num_phy):
redundant_indices = np.argmax(weight / logcnt, axis=-1) redundant_indices = np.argmax(weight / logcnt, axis=-1)
phy2log[:, i] = redundant_indices phy2log[:, i] = redundant_indices
replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1 logcnt[arangen, redundant_indices] += 1
return phy2log, replica_idx, logcnt return phy2log, logcnt
@classmethod @classmethod
def rebalance_experts_hierarchical( def rebalance_experts_hierarchical(
@@ -111,7 +108,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_groups: int, num_groups: int,
num_nodes: int, num_nodes: int,
num_gpus: int, num_gpus: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> np.ndarray:
""" """
Parameters: Parameters:
weight: [num_moe_layers, num_logical_experts] weight: [num_moe_layers, num_logical_experts]
@@ -124,10 +121,6 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [layers, num_replicas], the expert phy2log: [layers, num_replicas], the expert
index of each replica index of each replica
pphy_replicas_idx: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
""" """
num_layers, num_logical_experts = weight.shape num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0 assert num_logical_experts % num_groups == 0
@@ -167,7 +160,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape( tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
-1, num_logical_experts // num_nodes -1, num_logical_experts // num_nodes
) )
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts( phy2mlog, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes tokens_per_mlog, num_physical_experts // num_nodes
) )
@@ -193,22 +186,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
).reshape(num_layers, -1) ).reshape(num_layers, -1)
# Map node-local logical indices back to global logical expert ids. # Map node-local logical indices back to global logical expert ids.
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1) pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
# Reorder replica ranks to the post-packing physical ordering. return pphy2log
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
num_layers, -1
)
# Convert replica counts back to the original logical ordering.
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
return pphy2log, pphy_replicas_idx, logcnt
@classmethod @classmethod
def preserve_intragpu_slots( def preserve_intragpu_slots(
cls, cls,
phy2log: np.ndarray, phy2log: np.ndarray,
phy_replicas_idx: np.ndarray,
num_ranks: int, num_ranks: int,
old_phy2log: np.ndarray, old_phy2log: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]: ) -> np.ndarray:
""" """
Reorder the new mapping per GPU so that experts that remain on the same GPU Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU keep their previous slot positions when possible. Incoming experts to that GPU
@@ -218,14 +204,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
""" """
num_phy_experts = phy2log.shape[1] num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0: if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx return phy2log
# Move to CPU and convert to NumPy for processing # Move to CPU and convert to NumPy for processing
slots_per_gpu = num_phy_experts // num_ranks slots_per_gpu = num_phy_experts // num_ranks
num_layers = phy2log.shape[0] num_layers = phy2log.shape[0]
post_phy2log = phy2log.copy() post_phy2log = phy2log.copy()
post_phy_replicas_idx = phy_replicas_idx.copy()
for gpu_idx in range(num_ranks): for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu start = gpu_idx * slots_per_gpu
@@ -233,7 +218,6 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
# Experts across all layers for this GPU # Experts across all layers for this GPU
old_local = old_phy2log[:, start:end] # [layers, slots] old_local = old_phy2log[:, start:end] # [layers, slots]
new_local = phy2log[:, start:end] # [layers, slots] new_local = phy2log[:, start:end] # [layers, slots]
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
@@ -253,9 +237,6 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
post_phy2log[layer_indices, start + slot_idx] = new_local[ post_phy2log[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions layer_indices, matched_new_positions
] ]
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
layer_indices, matched_new_positions
]
used_new_indices[layer_indices, matched_new_positions] = True used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True preserved_positions[layer_indices, slot_idx] = True
@@ -287,11 +268,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
post_phy2log[layer_idx, start + dst_pos] = new_local[ post_phy2log[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos layer_idx, src_pos
] ]
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos
]
return post_phy2log, post_phy_replicas_idx return post_phy2log
@classmethod @classmethod
def rebalance_experts( def rebalance_experts(
@@ -302,7 +280,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_nodes: int, num_nodes: int,
num_ranks: int, num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None, old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
""" """
Entry point for expert-parallelism load balancer. Entry point for expert-parallelism load balancer.
@@ -321,13 +299,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [layers, num_replicas], the expert phy2log: [layers, num_replicas], the expert
index of each replica index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
""" """
device = weight.device
num_layers, num_logical_experts = weight.shape
weight_np = weight.float().cpu().numpy() weight_np = weight.float().cpu().numpy()
old_phy2log_np = ( old_phy2log_np = (
old_global_expert_indices.cpu().numpy() old_global_expert_indices.cpu().numpy()
@@ -337,17 +309,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
if num_groups % num_nodes == 0: if num_groups % num_nodes == 0:
# use hierarchical load-balance policy # use hierarchical load-balance policy
phy2log_np, phy_replicas_idx_np, logcnt_np = ( phy2log_np = cls.rebalance_experts_hierarchical(
cls.rebalance_experts_hierarchical( weight_np, num_replicas, num_groups, num_nodes, num_ranks
weight_np, num_replicas, num_groups, num_nodes, num_ranks
)
) )
else: else:
# use global load-balance policy # use global load-balance policy
phy2log_np, phy_replicas_idx_np, logcnt_np = ( phy2log_np = cls.rebalance_experts_hierarchical(
cls.rebalance_experts_hierarchical( weight_np, num_replicas, 1, 1, num_ranks
weight_np, num_replicas, 1, 1, num_ranks
)
) )
# Optional postprocessing to preserve slots for experts moving # Optional postprocessing to preserve slots for experts moving
@@ -355,22 +323,10 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
# Only apply when the number of GPUs and slots per GPU remain unchanged. # Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move # Helps to avoid unnecessary weight copying when experts move
# within the same GPU. # within the same GPU.
if old_global_expert_indices is not None: if old_phy2log_np is not None:
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots( phy2log_np = cls.preserve_intragpu_slots(
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np phy2log_np, num_ranks, old_phy2log_np
) )
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy_np = np.full(
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
)
layer_indices = np.arange(num_layers)[:, None]
replica_indices = np.tile(
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
)
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
phy2log = torch.from_numpy(phy2log_np).to(device) phy2log = torch.from_numpy(phy2log_np)
log2phy = torch.from_numpy(log2phy_np).to(device) return phy2log
logcnt = torch.from_numpy(logcnt_np).to(device)
return phy2log, log2phy, logcnt