Abstract eplb algo (#26471)

Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com>
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Mercykid-bash
2025-12-05 03:09:09 +08:00
committed by GitHub
parent e10c84e06a
commit 1119f6e47a
8 changed files with 364 additions and 285 deletions

View File

@@ -4,7 +4,7 @@
import pytest
import torch
from vllm.distributed.eplb.rebalance_algo import rebalance_experts
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
def test_basic_rebalance():
@@ -23,7 +23,7 @@ def test_basic_rebalance():
num_nodes = 2
num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -77,7 +77,7 @@ def test_single_gpu_case():
num_nodes = 1
num_gpus = 1
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -99,7 +99,7 @@ def test_equal_weights():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -150,7 +150,7 @@ def test_multiple_layers():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -175,14 +175,14 @@ def test_parameter_validation():
# Test non-divisible case - this should be handled normally without throwing
# errors because the function will fall back to the global load balancing
# strategy
phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
assert phy2log.shape == (1, 8)
assert logcnt.shape == (1, 4)
# Test cases that will actually cause errors:
# num_physical_experts not divisible by num_gpus
with pytest.raises(AssertionError):
rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
def test_small_scale_hierarchical():
@@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
num_nodes = 2 # 2 nodes
num_gpus = 4 # 4 GPUs
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -246,7 +246,7 @@ def test_device_compatibility(device):
num_nodes = 1
num_gpus = 2
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@@ -263,7 +263,9 @@ def test_additional_cases():
weight1 = torch.tensor(
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
)
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
weight1, 24, 8, 4, 8
)
assert phy2log1.shape == (1, 24)
assert logcnt1.shape == (1, 16)
@@ -276,7 +278,9 @@ def test_additional_cases():
[12, 25, 50, 100, 150, 200], # Increasing weights
]
)
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
weight2, 10, 3, 1, 2
)
assert phy2log2.shape == (2, 10)
assert logcnt2.shape == (2, 6)
@@ -300,7 +304,7 @@ if __name__ == "__main__":
num_nodes = 2
num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
print(phy2log)