[EPLB] Cleanup the transfer logic for the various eplb maps (#34520)

Signed-off-by: Sage Moore <sagmoore@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore
2026-03-27 02:18:46 -07:00
committed by GitHub
parent 6287e7fa20
commit 497e234d38
3 changed files with 247 additions and 76 deletions

View File

@@ -8,8 +8,10 @@ steps:
source_file_dependencies:
- vllm/distributed/eplb
- tests/distributed/test_eplb_algo.py
- tests/distributed/test_eplb_utils.py
commands:
- pytest -v -s distributed/test_eplb_algo.py
- pytest -v -s distributed/test_eplb_utils.py
- label: EPLB Execution
timeout_in_minutes: 20

View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
import torch
from vllm.distributed.eplb.eplb_state import (
_commit_eplb_maps,
_commit_eplb_maps_for_layer,
)
def _make_model_state(
phy2log: torch.Tensor,
log2phy: torch.Tensor,
logcnt: torch.Tensor,
) -> MagicMock:
"""Build a minimal EplbModelState mock with only the three map tensors."""
state = MagicMock()
state.physical_to_logical_map = phy2log
state.logical_to_physical_map = log2phy
state.logical_replica_count = logcnt
return state
def test_commit_eplb_maps_shape_change():
    """
    _commit_eplb_maps usually writes physical_to_logical_map in place, but when
    the number of physical experts changes the old tensor must be swapped out
    for the new one wholesale.
    """
    num_layers, num_logical, num_physical = 2, 4, 6
    max_replicas = 3
    grown_physical = num_physical + 2

    # Current state, sized for the original number of physical experts.
    current_phy2log = torch.zeros(num_layers, num_physical, dtype=torch.long)
    current_log2phy = torch.full(
        (num_layers, num_logical, max_replicas), -1, dtype=torch.long
    )
    current_logcnt = torch.zeros(num_layers, num_logical, dtype=torch.long)
    model_state = _make_model_state(
        phy2log=current_phy2log,
        log2phy=current_log2phy,
        logcnt=current_logcnt,
    )

    # Two additional physical experts; the modulo wraps them back onto the
    # first two logical experts.
    wrapped = torch.arange(grown_physical, dtype=torch.long) % num_logical
    new_phy2log_larger = wrapped.unsqueeze(0).expand(num_layers, -1)

    _commit_eplb_maps(model_state, new_phy2log_larger)

    # The committed map must have the new width and the new contents.
    committed = model_state.physical_to_logical_map
    assert committed.shape[1] == grown_physical
    assert torch.equal(committed, new_phy2log_larger)
def test_commit_eplb_maps_for_layer_logical_padding():
    """
    Test that logical_to_physical_map is padded with -1 to fill the
    pre-allocated slots when the new map has fewer replicas than the max.
    """
    num_layers, num_logical, num_physical = 2, 4, 6
    max_replicas = 3
    # Pre-fill the destination with a value that is NOT -1. If it started out
    # as -1, the assertion below would already hold before the commit and the
    # test could not tell whether the padding slot was actually written.
    stale_value = 99
    model_state = _make_model_state(
        phy2log=torch.zeros(num_layers, num_physical, dtype=torch.long),
        log2phy=torch.full(
            (num_layers, num_logical, max_replicas), stale_value, dtype=torch.long
        ),
        logcnt=torch.zeros(num_layers, num_logical, dtype=torch.long),
    )
    new_phy2log = (
        (torch.arange(num_physical, dtype=torch.long) % num_logical)
        .unsqueeze(0)
        .expand(num_layers, -1)
        .contiguous()
    )
    layer = 0
    other_layer = 1
    _commit_eplb_maps_for_layer(model_state, new_phy2log, layer)
    # Six physical experts spread over four logical experts gives no logical
    # expert more than two replicas, so the third slot must be -1 padding.
    assert torch.all(model_state.logical_to_physical_map[layer, :, 2] == -1)
    # The per-layer commit writes only the requested layer; the other layer
    # must still hold the stale pre-fill value.
    assert torch.all(model_state.logical_to_physical_map[other_layer] == stale_value)
def test_commit_eplb_maps_for_layer_shape_assert():
    """The per-layer commit must assert when the physical expert count differs."""
    num_layers, num_logical, num_physical = 2, 4, 6

    current_phy2log = torch.zeros(num_layers, num_physical, dtype=torch.long)
    current_log2phy = torch.full((num_layers, num_logical, 2), -1, dtype=torch.long)
    current_logcnt = torch.zeros(num_layers, num_logical, dtype=torch.long)
    model_state = _make_model_state(
        phy2log=current_phy2log,
        log2phy=current_log2phy,
        logcnt=current_logcnt,
    )

    # One physical expert more than the current state was allocated for.
    bad_phy2log = torch.zeros(num_layers, num_physical + 1, dtype=torch.long)

    with pytest.raises(AssertionError):
        _commit_eplb_maps_for_layer(model_state, bad_phy2log, layer=0)
def test_commit_eplb_maps():
    """A full commit must populate all three maps in model_state correctly."""
    num_layers, num_logical, num_physical, max_replicas = 2, 3, 4, 2

    # Start from empty maps: zeros for the physical map and replica counts,
    # -1 for every logical-to-physical slot.
    empty_phy2log = torch.zeros(num_layers, num_physical, dtype=torch.long)
    empty_log2phy = torch.full(
        (num_layers, num_logical, max_replicas), -1, dtype=torch.long
    )
    empty_logcnt = torch.zeros(num_layers, num_logical, dtype=torch.long)
    model_state = _make_model_state(
        phy2log=empty_phy2log,
        log2phy=empty_log2phy,
        logcnt=empty_logcnt,
    )

    # Placement to commit, plus the logical-side maps expected to result.
    new_phy2log = torch.tensor(
        [
            [0, 1, 2, 0],
            [1, 2, 0, 1],
        ],
        dtype=torch.long,
    )
    expected_log2phy = torch.tensor(
        [
            [[0, 3], [1, -1], [2, -1]],
            [[2, -1], [0, 3], [1, -1]],
        ],
        dtype=torch.long,
    )
    expected_logcnt = torch.tensor(
        [
            [2, 1, 1],
            [1, 2, 1],
        ],
        dtype=torch.long,
    )

    _commit_eplb_maps(model_state, new_phy2log)

    assert torch.equal(model_state.physical_to_logical_map, new_phy2log)
    assert torch.equal(model_state.logical_to_physical_map, expected_log2phy)
    assert torch.equal(model_state.logical_replica_count, expected_logcnt)
def test_commit_eplb_maps_for_layer():
    """Test that only the target layer is updated, across all three maps."""
    num_layers, num_logical, max_replicas = 2, 3, 2
    original_phy2log = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)
    original_log2phy = torch.full(
        (num_layers, num_logical, max_replicas), -1, dtype=torch.long
    )
    original_logcnt = torch.zeros(num_layers, num_logical, dtype=torch.long)
    # Hand clones to the state so the originals stay pristine for the
    # "untouched" comparisons below.
    model_state = _make_model_state(
        phy2log=original_phy2log.clone(),
        log2phy=original_log2phy.clone(),
        logcnt=original_logcnt.clone(),
    )
    new_phy2log = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 1]], dtype=torch.long)
    # Expected logical-side maps for layer 0 only; layer 1 is not committed.
    expected_log2phy_layer0 = torch.tensor(
        [[0, 3], [1, -1], [2, -1]], dtype=torch.long
    )
    expected_logcnt_layer0 = torch.tensor([2, 1, 1], dtype=torch.long)
    _commit_eplb_maps_for_layer(model_state, new_phy2log, layer=0)
    # Layer 0 updated
    assert torch.equal(model_state.physical_to_logical_map[0], new_phy2log[0])
    assert torch.equal(model_state.logical_to_physical_map[0], expected_log2phy_layer0)
    assert torch.equal(model_state.logical_replica_count[0], expected_logcnt_layer0)
    # Layer 1 untouched in every map, not just physical_to_logical_map
    assert torch.equal(model_state.physical_to_logical_map[1], original_phy2log[1])
    assert torch.equal(model_state.logical_to_physical_map[1], original_log2phy[1])
    assert torch.equal(model_state.logical_replica_count[1], original_logcnt[1])

View File

@@ -729,13 +729,6 @@ class EplbState:
eplb_model_state.physical_to_logical_map.cpu(),
)
num_logical_experts = global_expert_load_window.shape[-1]
(new_logical_to_physical_map, new_logical_replica_count) = (
compute_logical_maps(
new_physical_to_logical_map, num_logical_experts
)
)
# Update expert weights
rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map,
@@ -747,39 +740,11 @@ class EplbState:
)
if not is_profile:
if (
eplb_model_state.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1]
):
eplb_model_state.physical_to_logical_map = (
new_physical_to_logical_map.to(
eplb_model_state.physical_to_logical_map.device
)
)
else:
eplb_model_state.physical_to_logical_map.copy_(
new_physical_to_logical_map
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert (
max_physical_slots
<= eplb_model_state.logical_to_physical_map.shape[-1]
)
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(
0,
eplb_model_state.logical_to_physical_map.shape[-1]
- max_physical_slots,
),
value=-1,
)
eplb_model_state.logical_to_physical_map.copy_(
new_logical_to_physical_map
)
eplb_model_state.logical_replica_count.copy_(
new_logical_replica_count
_commit_eplb_maps(
eplb_model_state,
new_physical_to_logical_map=new_physical_to_logical_map,
)
if is_main_rank:
assert start_event is not None
assert end_event is not None
@@ -829,42 +794,6 @@ class EplbState:
is_profile=is_profile,
)
def _update_layer_mapping_from_new(
self, model_state: EplbModelState, layer: int
) -> None:
if model_state.new_physical_to_logical_map is None:
return
target_device = model_state.physical_to_logical_map.device
new_physical = model_state.new_physical_to_logical_map
# If the number of physical experts has changed, then the new map needs to
# be copied synchronously to avoid a race condition with the async worker
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
model_state.physical_to_logical_map = new_physical.to(target_device)
else:
model_state.physical_to_logical_map[layer].copy_(
new_physical[layer].to(target_device, non_blocking=True)
)
num_logical_experts = model_state.logical_to_physical_map.shape[1]
new_logical, new_replica_count = compute_logical_maps(
new_physical[layer], num_logical_experts
)
logical_device = model_state.logical_to_physical_map.device
max_slots = model_state.logical_to_physical_map.shape[-1]
slot_delta = max_slots - new_logical.shape[-1]
if slot_delta > 0:
new_logical = torch.nn.functional.pad(
new_logical, (0, slot_delta), value=-1
)
model_state.logical_to_physical_map[layer].copy_(new_logical.to(logical_device))
replica_device = model_state.logical_replica_count.device
model_state.logical_replica_count[layer].copy_(
new_replica_count.to(replica_device)
)
def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
parallel_state = get_ep_group()
cpu_group = getattr(parallel_state, "cpu_group", None)
@@ -936,7 +865,12 @@ class EplbState:
model_state.buffer_consumed_event = consumed_event
transferred_layer = model_state.layer_to_transfer
self._update_layer_mapping_from_new(model_state, transferred_layer)
assert model_state.new_physical_to_logical_map is not None
_commit_eplb_maps_for_layer(
model_state,
new_physical_to_logical_map=model_state.new_physical_to_logical_map,
layer=transferred_layer,
)
# After the main thread consumes, advance layer_to_transfer
model_state.layer_to_transfer += 1
model_state.ep_buffer_ready = 0
@@ -1175,3 +1109,84 @@ def compute_logical_maps(
if per_layer:
return logical_to_physical_map_out.squeeze(0), logical_replica_count.squeeze(0)
return logical_to_physical_map_out, logical_replica_count
def _pad_out_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
src_padding = dst.shape[-1] - src.shape[-1]
assert src_padding >= 0
new_src = torch.nn.functional.pad(src, (0, src_padding), value=-1)
dst.copy_(new_src)
def _commit_eplb_maps_for_layer(
    model_state: EplbModelState,
    new_physical_to_logical_map: torch.Tensor,
    layer: int,
) -> None:
    """
    Single-layer counterpart of _commit_eplb_maps, used by the sync portion of
    EPLB when running async EPLB.

    Writes row ``layer`` of physical_to_logical_map, logical_to_physical_map and
    logical_replica_count in model_state in place. The logical-side rows are
    derived from the new physical row via compute_logical_maps. Once this
    returns, the new mapping for that layer is the current one seen through
    model_state.

    Raises:
        AssertionError: If the new physical row's shape differs from the
            current one, or the derived replica-count row does not match the
            stored row's shape.
    """
    # --- physical_to_logical_map -------------------------------------------
    new_row = new_physical_to_logical_map[layer]
    cur_row = model_state.physical_to_logical_map[layer]
    assert new_row.shape == cur_row.shape, (
        "The number of physical experts must stay the same while running Async EPLB. "
        f"Current number of physical experts: {cur_row.shape[0]}. "
        f"New number of physical experts {new_row.shape[0]}."
    )
    cur_row.copy_(new_row, non_blocking=True)

    # Derive this layer's logical-side maps from the new physical row.
    logical_to_physical, replica_count = compute_logical_maps(
        new_row, model_state.logical_to_physical_map.shape[1]
    )

    # --- logical_to_physical_map -------------------------------------------
    # The derived map may have fewer replica slots than were pre-allocated;
    # the helper fills the remainder with -1.
    _pad_out_tensor(
        src=logical_to_physical,
        dst=model_state.logical_to_physical_map[layer],
    )

    # --- logical_replica_count ---------------------------------------------
    cur_count = model_state.logical_replica_count[layer]
    assert replica_count.shape == cur_count.shape
    cur_count.copy_(replica_count, non_blocking=True)
def _commit_eplb_maps(
    model_state: EplbModelState,
    new_physical_to_logical_map: torch.Tensor,
) -> None:
    """
    Make a new expert placement the current one for every layer.

    Stores the new physical_to_logical_map in model_state, then derives
    logical_to_physical_map and logical_replica_count from it via
    compute_logical_maps and writes those into model_state in place. Once this
    returns, the new mappings are the current mappings and are visible to the
    model.
    """
    # --- physical_to_logical_map -------------------------------------------
    current_phy2log = model_state.physical_to_logical_map
    if new_physical_to_logical_map.shape[1] == current_phy2log.shape[1]:
        # Common case: same number of physical experts, so overwrite in place.
        current_phy2log.copy_(new_physical_to_logical_map, non_blocking=True)
    else:
        # Rare case: the number of physical experts changed, which only happens
        # when the number of GPUs available to vLLM changes while vLLM is
        # running. Discard the old map and use the new one on the same device.
        model_state.physical_to_logical_map = new_physical_to_logical_map.to(
            current_phy2log.device
        )

    # Derive the logical-side maps from the new physical map.
    logical_to_physical, replica_count = compute_logical_maps(
        new_physical_to_logical_map,
        model_state.logical_to_physical_map.shape[1],
    )

    # --- logical_to_physical_map -------------------------------------------
    # Unused pre-allocated replica slots are filled with -1 by the helper.
    _pad_out_tensor(
        src=logical_to_physical,
        dst=model_state.logical_to_physical_map,
    )

    # --- logical_replica_count ---------------------------------------------
    model_state.logical_replica_count.copy_(replica_count, non_blocking=True)