[EPLB] Simplify EPLB rearrange by only returning one map (#36267)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.eplb.eplb_state import compute_logical_maps
|
||||
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
|
||||
|
||||
|
||||
@@ -24,9 +25,10 @@ def test_basic_rebalance():
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Verify output shapes
|
||||
assert phy2log.shape == (
|
||||
@@ -78,9 +80,10 @@ def test_single_gpu_case():
|
||||
num_nodes = 1
|
||||
num_gpus = 1
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 4)
|
||||
@@ -100,9 +103,10 @@ def test_equal_weights():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 8)
|
||||
@@ -123,9 +127,10 @@ def test_extreme_weight_imbalance():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (1, 12)
|
||||
@@ -151,9 +156,10 @@ def test_multiple_layers():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Verify shapes
|
||||
assert phy2log.shape == (3, 8)
|
||||
@@ -176,7 +182,8 @@ def test_parameter_validation():
|
||||
# Test non-divisible case - this should handle normally without throwing
|
||||
# errors because the function will fall back to global load balancing
|
||||
# strategy
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
assert phy2log.shape == (1, 8)
|
||||
assert logcnt.shape == (1, 4)
|
||||
|
||||
@@ -198,9 +205,10 @@ def test_small_scale_hierarchical():
|
||||
num_nodes = 2 # 2 nodes
|
||||
num_gpus = 4 # 4 GPUs
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Verify basic constraints
|
||||
assert phy2log.shape == (1, 12)
|
||||
@@ -225,9 +233,10 @@ def test_global_load_balance_fallback():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Should work normally, just using global load balancing strategy
|
||||
assert phy2log.shape == (1, 8)
|
||||
@@ -247,9 +256,10 @@ def test_device_compatibility(device):
|
||||
num_nodes = 1
|
||||
num_gpus = 2
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
_, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
|
||||
|
||||
# Function will convert to CPU internally, but should handle different
|
||||
# device inputs normally
|
||||
@@ -264,9 +274,8 @@ def test_additional_cases():
|
||||
weight1 = torch.tensor(
|
||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
||||
)
|
||||
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
|
||||
weight1, 24, 8, 4, 8
|
||||
)
|
||||
phy2log1 = DefaultEplbPolicy.rebalance_experts(weight1, 24, 8, 4, 8)
|
||||
_, logcnt1 = compute_logical_maps(phy2log1, weight1.shape[-1])
|
||||
|
||||
assert phy2log1.shape == (1, 24)
|
||||
assert logcnt1.shape == (1, 16)
|
||||
@@ -279,9 +288,8 @@ def test_additional_cases():
|
||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||
]
|
||||
)
|
||||
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
|
||||
weight2, 10, 3, 1, 2
|
||||
)
|
||||
phy2log2 = DefaultEplbPolicy.rebalance_experts(weight2, 10, 3, 1, 2)
|
||||
_, logcnt2 = compute_logical_maps(phy2log2, weight2.shape[-1])
|
||||
|
||||
assert phy2log2.shape == (2, 10)
|
||||
assert logcnt2.shape == (2, 6)
|
||||
@@ -292,6 +300,42 @@ def test_additional_cases():
|
||||
assert logcnt2[layer, max_weight_idx] >= 2
|
||||
|
||||
|
||||
def test_compute_logical_maps_with_negative_indices():
    """
    Test that compute_logical_maps correctly handles physical slots containing
    -1 (unused slots).

    Each logical expert appears exactly once per layer in the input, so every
    replica count must be 1 and log2phy must hold exactly one valid physical
    slot index per (layer, expert) pair.
    """
    # 2 layers, 6 physical slots, 4 logical experts.
    # Layer 0 leaves slots 2 and 5 unused (-1); layer 1 leaves slots 1 and 5
    # unused, and lists the experts in a different order so that a mapping
    # that ignores the layer dimension cannot pass by accident.
    phy2log = torch.tensor(
        [
            [0, 1, -1, 2, 3, -1],
            [3, -1, 2, 1, 0, -1],
        ]
    )
    num_layers = 2
    num_logical_experts = 4

    log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts)

    assert logcnt.shape == (num_layers, num_logical_experts)
    assert log2phy.shape == (num_layers, num_logical_experts, 1)

    expected_logcnt = torch.ones(num_layers, num_logical_experts, dtype=phy2log.dtype)
    assert torch.all(logcnt == expected_logcnt), (
        f"Expected that all replica counts == 1, got {logcnt}"
    )

    assert torch.all(log2phy >= 0), (
        "log2phy should only contain valid physical indices, not -1"
    )

    # Verify the complete logical -> physical mapping for BOTH layers.
    # Comparing via tolist() keeps the check independent of the integer dtype
    # and device that compute_logical_maps chooses for its output.
    expected_log2phy = [
        [0, 1, 3, 4],  # layer 0: experts 0..3 live in slots 0, 1, 3, 4
        [4, 3, 2, 0],  # layer 1: experts 0..3 live in slots 4, 3, 2, 0
    ]
    actual_log2phy = log2phy[:, :, 0].tolist()
    assert actual_log2phy == expected_log2phy, (
        f"Expected log2phy {expected_log2phy}, got {actual_log2phy}"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weight = torch.tensor(
|
||||
[
|
||||
@@ -305,7 +349,7 @@ if __name__ == "__main__":
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
phy2log = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
print(phy2log)
|
||||
@@ -434,9 +478,10 @@ def test_preserve_intragpu_slots(
|
||||
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
|
||||
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
|
||||
|
||||
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
|
||||
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
|
||||
post_phy2log = DefaultEplbPolicy.preserve_intragpu_slots(
|
||||
new_phy2log, num_ranks, old_phy2log
|
||||
)
|
||||
post_phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(post_phy2log)
|
||||
|
||||
# Shapes preserved
|
||||
assert post_phy2log.shape == new_phy2log.shape
|
||||
|
||||
Reference in New Issue
Block a user