[EPLB] Optimize eplb mapping and record in router for prefill (#36261)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
Ilya Markov
2026-03-30 21:48:33 +02:00
committed by GitHub
parent 494636b29d
commit 12701e8af2
6 changed files with 338 additions and 66 deletions

View File

@@ -8,6 +8,9 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.base_router import (
eplb_map_to_physical_and_record,
)
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
@@ -55,11 +58,13 @@ def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerSta
logical_replica_count = torch.ones(
global_num_experts, dtype=torch.int64, device="cuda"
)
should_record_tensor = torch.ones((), dtype=torch.bool, device="cuda")
return EplbLayerState(
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
should_record_tensor=should_record_tensor,
)
@@ -581,3 +586,152 @@ def test_custom(
# hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# ---------------------------------------------------------------------------
# Tests for eplb_map_to_physical_and_record
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize(
    "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
    [
        pytest.param(
            # logical i → physical i
            [[0], [1], [2], [3]],
            [1, 1, 1, 1],
            4,
            [[0, 1], [2, 3], [0, 2]],
            [[0, 1], [2, 3], [0, 2]],
            [2, 1, 2, 1],
            id="identity",
        ),
        pytest.param(
            # logical 0→3, 1→0, 2→1, 3→2
            [[3], [0], [1], [2]],
            [1, 1, 1, 1],
            4,
            [[0, 1], [2, 3], [0, 2]],
            [[3, 0], [1, 2], [3, 1]],
            [1, 2, 1, 2],
            id="shuffled",
        ),
        pytest.param(
            # logical 0→5, 1→2, 2→7, 3→0 in a larger physical space
            [[5], [2], [7], [0]],
            [1, 1, 1, 1],
            8,
            [[0, 1], [2, 3]],
            [[5, 2], [7, 0]],
            [1, 0, 1, 0, 0, 1, 0, 1],
            id="sparse",
        ),
    ],
)
def test_eplb_map_no_redundancy(
    record_enabled,
    l2p_map,
    replica_count,
    num_physical,
    topk_ids,
    expected_out,
    expected_load,
):
    """Single-replica mapping: each logical expert maps to exactly one
    physical expert, so the translation is a pure gather through the
    logical→physical table. Load is accumulated only when recording is on."""
    device = "cuda"
    logical_to_physical = torch.tensor(l2p_map, dtype=torch.int64, device=device)
    replicas = torch.tensor(replica_count, dtype=torch.int64, device=device)
    load_view = torch.zeros(num_physical, dtype=torch.int32, device=device)
    record_flag = torch.tensor(record_enabled, dtype=torch.bool, device=device)
    logical_ids = torch.tensor(topk_ids, dtype=torch.int32, device=device)

    mapped = eplb_map_to_physical_and_record(
        topk_ids=logical_ids,
        expert_load_view=load_view,
        logical_to_physical_map=logical_to_physical,
        logical_replica_count=replicas,
        record_enabled=record_flag,
    )

    # The physical ids must match the expected gather result exactly.
    expected = torch.tensor(expected_out, dtype=mapped.dtype, device=device)
    torch.testing.assert_close(mapped, expected)

    if not record_enabled:
        # Recording disabled: the load view must remain untouched (all zeros).
        assert load_view.sum().item() == 0
    else:
        want_load = torch.tensor(expected_load, dtype=torch.int32, device=device)
        torch.testing.assert_close(load_view, want_load)
@pytest.mark.parametrize("record_enabled", [True, False])
@pytest.mark.parametrize(
    "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load",
    [
        pytest.param(
            # experts 0,1 have 2 replicas; 2,3 have 1
            [[0, 4], [1, 5], [2, -1], [3, -1]],
            [2, 2, 1, 1],
            6,
            [[0, 1], [2, 3], [0, 2]],
            # offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%1=0→p2,
            # 3→3%1=0→p3, 4→4%2=0→p0, 5→5%1=0→p2
            [[0, 5], [2, 3], [0, 2]],
            [2, 0, 2, 1, 0, 1],
            id="partial",
        ),
        pytest.param(
            # all 4 experts have 2 replicas
            [[0, 4], [1, 5], [2, 6], [3, 7]],
            [2, 2, 2, 2],
            8,
            [[0, 1], [2, 3], [0, 2]],
            # offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%2=0→p2,
            # 3→3%2=1→p7, 4→4%2=0→p0, 5→5%2=1→p6
            [[0, 5], [2, 7], [0, 6]],
            [2, 0, 1, 0, 0, 1, 1, 1],
            id="full",
        ),
        pytest.param(
            # expert 0: 4 replicas, experts 1,2: 2 replicas
            [[0, 3, 5, 7], [1, 4, -1, -1], [2, 6, -1, -1]],
            [4, 2, 2],
            8,
            [[0, 1], [2, 0], [1, 2]],
            # offs: 0→0%4=0→p0, 1→1%2=1→p4, 2→2%2=0→p2,
            # 3→3%4=3→p7, 4→4%2=0→p1, 5→5%2=1→p6
            [[0, 4], [2, 7], [1, 6]],
            [1, 1, 1, 0, 1, 0, 1, 1],
            id="uneven",
        ),
    ],
)
def test_eplb_map_with_redundancy(
    record_enabled,
    l2p_map,
    replica_count,
    num_physical,
    topk_ids,
    expected_out,
    expected_load,
):
    """Multi-replica mapping: logical experts with replica_count > 1 are
    spread across their physical replicas (offset % replica_count picks the
    replica — see the per-case `offs` comments). Load is accumulated on the
    chosen physical slots only when recording is enabled."""
    device = "cuda"
    logical_to_physical = torch.tensor(l2p_map, dtype=torch.int64, device=device)
    replicas = torch.tensor(replica_count, dtype=torch.int64, device=device)
    load_view = torch.zeros(num_physical, dtype=torch.int32, device=device)
    record_flag = torch.tensor(record_enabled, dtype=torch.bool, device=device)
    logical_ids = torch.tensor(topk_ids, dtype=torch.int32, device=device)

    mapped = eplb_map_to_physical_and_record(
        topk_ids=logical_ids,
        expert_load_view=load_view,
        logical_to_physical_map=logical_to_physical,
        logical_replica_count=replicas,
        record_enabled=record_flag,
    )

    # The physical ids must match the expected replica selection exactly.
    expected = torch.tensor(expected_out, dtype=mapped.dtype, device=device)
    torch.testing.assert_close(mapped, expected)

    if not record_enabled:
        # Recording disabled: the load view must remain untouched (all zeros).
        assert load_view.sum().item() == 0
    else:
        want_load = torch.tensor(expected_load, dtype=torch.int32, device=device)
        torch.testing.assert_close(load_view, want_load)