Abstract eplb algo (#26471)
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com>
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -4,7 +4,7 @@
 import pytest
 import torch
 
-from vllm.distributed.eplb.rebalance_algo import rebalance_experts
+from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
 
 
 def test_basic_rebalance():
@@ -23,7 +23,7 @@ def test_basic_rebalance():
     num_nodes = 2
     num_gpus = 8
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -77,7 +77,7 @@ def test_single_gpu_case():
     num_nodes = 1
     num_gpus = 1
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -99,7 +99,7 @@ def test_equal_weights():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -150,7 +150,7 @@ def test_multiple_layers():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -175,14 +175,14 @@ def test_parameter_validation():
     # Test non-divisible case - this should handle normally without throwing
     # errors because the function will fall back to global load balancing
     # strategy
-    phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
     assert phy2log.shape == (1, 8)
     assert logcnt.shape == (1, 4)
 
     # Test cases that will actually cause errors:
     # num_physical_experts not divisible by num_gpus
     with pytest.raises(AssertionError):
-        rebalance_experts(weight, 7, 2, 2, 4)  # 7 not divisible by 4
+        DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4)  # 7 not divisible by 4
 
 
 def test_small_scale_hierarchical():
@@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
     num_nodes = 2  # 2 nodes
     num_gpus = 4  # 4 GPUs
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -246,7 +246,7 @@ def test_device_compatibility(device):
     num_nodes = 1
     num_gpus = 2
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
 
@@ -263,7 +263,9 @@ def test_additional_cases():
     weight1 = torch.tensor(
         [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
     )
-    phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
+    phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
+        weight1, 24, 8, 4, 8
+    )
 
     assert phy2log1.shape == (1, 24)
     assert logcnt1.shape == (1, 16)
@@ -276,7 +278,9 @@ def test_additional_cases():
            [12, 25, 50, 100, 150, 200],  # Increasing weights
        ]
     )
-    phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
+    phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
+        weight2, 10, 3, 1, 2
+    )
 
     assert phy2log2.shape == (2, 10)
     assert logcnt2.shape == (2, 6)
@@ -300,7 +304,7 @@ if __name__ == "__main__":
     num_nodes = 2
     num_gpus = 8
 
-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
     print(phy2log)
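
For context, here is a minimal sketch of the call pattern these tests migrate to: the rebalance entry point is invoked through the DefaultEplbPolicy class rather than as the free function rebalance_experts. The import path, method name, and argument order are taken from the diff above; the weight values and parameter choices below are illustrative assumptions, not taken from the vLLM test suite.

import torch

from vllm.distributed.eplb.policy.default import DefaultEplbPolicy

# Illustrative per-layer expert load: 1 layer x 4 logical experts (made-up values).
weight = torch.tensor([[10, 40, 25, 25]])

# Replicate 4 logical experts onto 8 physical slots across 1 node with 2 GPUs.
# Argument order follows the calls in the diff:
# (weight, num_replicas, num_groups, num_nodes, num_gpus).
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 2, 1, 2)

print(phy2log.shape)  # expected (1, 8): physical-slot -> logical-expert map per layer
print(logcnt.shape)   # expected (1, 4): replica count per logical expert

As in the tests above, num_replicas should be divisible by num_gpus; the non-divisible group case falls back to the global load-balancing strategy instead of raising.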