[EPLB] Simplify EPLB rearrange by only returning one map (#36267)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Sage Moore
2026-03-18 17:34:00 -07:00
committed by GitHub
parent ef2c4f778d
commit c32a58cc2a
5 changed files with 197 additions and 160 deletions

View File

@@ -5,6 +5,7 @@ import numpy as np
import pytest
import torch
from vllm.distributed.eplb.eplb_state import compute_logical_maps
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
@@ -24,9 +25,10 @@ def test_basic_rebalance():
num_nodes = 2
num_gpus = 8
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify output shapes
assert phy2log.shape == (
@@ -78,9 +80,10 @@ def test_single_gpu_case():
num_nodes = 1
num_gpus = 1
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes
assert phy2log.shape == (1, 4)
@@ -100,9 +103,10 @@ def test_equal_weights():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes
assert phy2log.shape == (1, 8)
@@ -123,9 +127,10 @@ def test_extreme_weight_imbalance():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes
assert phy2log.shape == (1, 12)
@@ -151,9 +156,10 @@ def test_multiple_layers():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify shapes
assert phy2log.shape == (3, 8)
@@ -176,7 +182,8 @@ def test_parameter_validation():
# Test non-divisible case - this should handle normally without throwing
# errors because the function will fall back to global load balancing
# strategy
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
assert phy2log.shape == (1, 8)
assert logcnt.shape == (1, 4)
@@ -198,9 +205,10 @@ def test_small_scale_hierarchical():
num_nodes = 2 # 2 nodes
num_gpus = 4 # 4 GPUs
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Verify basic constraints
assert phy2log.shape == (1, 12)
@@ -225,9 +233,10 @@ def test_global_load_balance_fallback():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Should work normally, just using global load balancing strategy
assert phy2log.shape == (1, 8)
@@ -247,9 +256,10 @@ def test_device_compatibility(device):
num_nodes = 1
num_gpus = 2
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
# Function will convert to CPU internally, but should handle different
# device inputs normally
@@ -264,9 +274,8 @@ def test_additional_cases():
weight1 = torch.tensor(
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
)
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
weight1, 24, 8, 4, 8
)
phy2log1 = DefaultEplbPolicy.rebalance_experts(weight1, 24, 8, 4, 8)
_, logcnt1 = compute_logical_maps(phy2log1, weight1.shape[-1])
assert phy2log1.shape == (1, 24)
assert logcnt1.shape == (1, 16)
@@ -279,9 +288,8 @@ def test_additional_cases():
[12, 25, 50, 100, 150, 200], # Increasing weights
]
)
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
weight2, 10, 3, 1, 2
)
phy2log2 = DefaultEplbPolicy.rebalance_experts(weight2, 10, 3, 1, 2)
_, logcnt2 = compute_logical_maps(phy2log2, weight2.shape[-1])
assert phy2log2.shape == (2, 10)
assert logcnt2.shape == (2, 6)
@@ -292,6 +300,42 @@ def test_additional_cases():
assert logcnt2[layer, max_weight_idx] >= 2
def test_compute_logical_maps_with_negative_indices():
    """
    Test that compute_logical_maps correctly handles physical slots containing
    -1 (unused slots).
    """
    # Two layers, six physical slots, four logical experts.
    # Slots 2 and 5 in each layer hold -1, marking them as unused.
    num_layers = 2
    num_logical_experts = 4
    phy2log = torch.tensor(
        [
            [0, 1, -1, 2, 3, -1],
            [3, -1, 2, 1, 0, -1],
        ]
    )

    log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts)

    # Shape checks: one replica column, since every logical expert
    # appears exactly once among the valid (non -1) slots.
    assert logcnt.shape == (num_layers, num_logical_experts)
    assert log2phy.shape == (num_layers, num_logical_experts, 1)

    # Every logical expert should be counted exactly once.
    expected_logcnt = torch.ones(num_layers, num_logical_experts, dtype=phy2log.dtype)
    assert torch.all(logcnt == expected_logcnt), (
        f"Expected that all replica counts == 1, got {logcnt}"
    )

    # -1 entries must never leak into the reverse (logical -> physical) map.
    assert torch.all(log2phy >= 0), (
        "log2phy should only contain valid physical indices, not -1"
    )

    # Spot-check layer 0: each logical expert maps back to the slot
    # where it appears in phy2log (skipping the -1 slots).
    assert log2phy[0, 0, 0] == 0
    assert log2phy[0, 1, 0] == 1
    assert log2phy[0, 2, 0] == 3
    assert log2phy[0, 3, 0] == 4
if __name__ == "__main__":
weight = torch.tensor(
[
@@ -305,7 +349,7 @@ if __name__ == "__main__":
num_nodes = 2
num_gpus = 8
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
phy2log = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
print(phy2log)
@@ -434,9 +478,10 @@ def test_preserve_intragpu_slots(
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
post_phy2log = DefaultEplbPolicy.preserve_intragpu_slots(
new_phy2log, num_ranks, old_phy2log
)
post_phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(post_phy2log)
# Shapes preserved
assert post_phy2log.shape == new_phy2log.shape