155 lines
5.5 KiB
Python
155 lines
5.5 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
from unittest.mock import MagicMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from vllm.distributed.eplb.eplb_state import (
|
||
|
|
_commit_eplb_maps,
|
||
|
|
_commit_eplb_maps_for_layer,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _make_model_state(
|
||
|
|
phy2log: torch.Tensor,
|
||
|
|
log2phy: torch.Tensor,
|
||
|
|
logcnt: torch.Tensor,
|
||
|
|
) -> MagicMock:
|
||
|
|
"""Build a minimal EplbModelState mock with only the three map tensors."""
|
||
|
|
state = MagicMock()
|
||
|
|
state.physical_to_logical_map = phy2log
|
||
|
|
state.logical_to_physical_map = log2phy
|
||
|
|
state.logical_replica_count = logcnt
|
||
|
|
return state
|
||
|
|
|
||
|
|
|
||
|
|
def test_commit_eplb_maps_shape_change():
    """
    The in-place copy path only applies while shapes match; when the number
    of physical experts changes, the stored physical_to_logical map must be
    replaced wholesale by the new tensor.
    """
    num_layers = 2
    num_logical = 4
    num_physical = 6
    max_replicas = 3

    # Pre-rebalance state tensors.
    model_state = _make_model_state(
        phy2log=torch.zeros(num_layers, num_physical, dtype=torch.long),
        log2phy=torch.full(
            (num_layers, num_logical, max_replicas), -1, dtype=torch.long
        ),
        logcnt=torch.zeros(num_layers, num_logical, dtype=torch.long),
    )

    # Replacement map with two extra physical experts; the additional slots
    # wrap around onto the first two logical experts via the modulo.
    wider = num_physical + 2
    new_phy2log_larger = (
        (torch.arange(wider, dtype=torch.long) % num_logical)
        .unsqueeze(0)
        .expand(num_layers, -1)
    )
    _commit_eplb_maps(model_state, new_phy2log_larger)

    # The stored map must now carry the enlarged shape and identical values.
    assert model_state.physical_to_logical_map.shape[1] == wider
    assert torch.equal(model_state.physical_to_logical_map, new_phy2log_larger)
|
||
|
|
|
||
|
|
|
||
|
|
def test_commit_eplb_maps_for_layer_logical_padding():
    """
    When the committed map uses fewer replicas than the pre-allocated
    maximum, the trailing replica slots of logical_to_physical_map must be
    left filled with -1.
    """
    num_layers = 2
    num_logical = 4
    num_physical = 6
    max_replicas = 3

    model_state = _make_model_state(
        phy2log=torch.zeros(num_layers, num_physical, dtype=torch.long),
        log2phy=torch.full(
            (num_layers, num_logical, max_replicas), -1, dtype=torch.long
        ),
        logcnt=torch.zeros(num_layers, num_logical, dtype=torch.long),
    )

    # Each logical expert appears at most twice across 6 physical slots, so
    # the third replica slot (index 2) should stay as padding after commit.
    flat_map = torch.arange(num_physical, dtype=torch.long) % num_logical
    new_phy2log = flat_map.unsqueeze(0).expand(num_layers, -1).contiguous()

    target_layer = 0
    _commit_eplb_maps_for_layer(model_state, new_phy2log, target_layer)

    assert torch.all(model_state.logical_to_physical_map[target_layer, :, 2] == -1)
|
||
|
|
|
||
|
|
|
||
|
|
def test_commit_eplb_maps_for_layer_shape_assert():
    """Committing a map whose physical-expert count disagrees with the
    stored state must trip the helper's shape assertion."""
    num_layers, num_logical, num_physical = 2, 4, 6

    model_state = _make_model_state(
        phy2log=torch.zeros(num_layers, num_physical, dtype=torch.long),
        log2phy=torch.full((num_layers, num_logical, 2), -1, dtype=torch.long),
        logcnt=torch.zeros(num_layers, num_logical, dtype=torch.long),
    )
    # One physical expert too many: the shapes no longer line up.
    mismatched = torch.zeros(num_layers, num_physical + 1, dtype=torch.long)
    with pytest.raises(AssertionError):
        _commit_eplb_maps_for_layer(model_state, mismatched, layer=0)
|
||
|
|
|
||
|
|
|
||
|
|
def test_commit_eplb_maps():
    """After a full commit, all three maps in model_state must reflect the
    values derived from the new physical_to_logical map."""
    num_layers, num_logical, num_physical, max_replicas = 2, 3, 4, 2

    model_state = _make_model_state(
        phy2log=torch.zeros(num_layers, num_physical, dtype=torch.long),
        log2phy=torch.full(
            (num_layers, num_logical, max_replicas), -1, dtype=torch.long
        ),
        logcnt=torch.zeros(num_layers, num_logical, dtype=torch.long),
    )

    committed_phy2log = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 1]], dtype=torch.long)
    # Maps expected to be derived from committed_phy2log by the commit.
    expected_log2phy = torch.tensor(
        [[[0, 3], [1, -1], [2, -1]], [[2, -1], [0, 3], [1, -1]]], dtype=torch.long
    )
    expected_logcnt = torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.long)

    _commit_eplb_maps(model_state, committed_phy2log)

    assert torch.equal(model_state.physical_to_logical_map, committed_phy2log)
    assert torch.equal(model_state.logical_to_physical_map, expected_log2phy)
    assert torch.equal(model_state.logical_replica_count, expected_logcnt)
|
||
|
|
|
||
|
|
|
||
|
|
def test_commit_eplb_maps_for_layer():
    """Committing a single layer must update that layer's maps and leave
    every other layer's tensors untouched.

    Fix: the docstring promised "only the target layer is updated", but the
    untouched-layer check previously covered only physical_to_logical_map;
    logical_to_physical_map and logical_replica_count of layer 1 are now
    verified as well.
    """
    num_layers, num_logical, max_replicas = 2, 3, 2

    # Sentinel values (9 / 8) make accidental writes to layer 1 easy to spot.
    original_phy2log = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)
    model_state = _make_model_state(
        phy2log=original_phy2log.clone(),
        log2phy=torch.full(
            (num_layers, num_logical, max_replicas), -1, dtype=torch.long
        ),
        logcnt=torch.zeros(num_layers, num_logical, dtype=torch.long),
    )

    new_phy2log = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 1]], dtype=torch.long)
    new_log2phy = torch.tensor(
        [[[0, 3], [1, -1], [2, -1]], [[2, -1], [0, 3], [1, -1]]], dtype=torch.long
    )
    new_logcnt = torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.long)

    _commit_eplb_maps_for_layer(model_state, new_phy2log, layer=0)

    # Layer 0 updated
    assert torch.equal(model_state.physical_to_logical_map[0], new_phy2log[0])
    assert torch.equal(model_state.logical_to_physical_map[0], new_log2phy[0])
    assert torch.equal(model_state.logical_replica_count[0], new_logcnt[0])

    # Layer 1 untouched: all three maps must retain their initial values
    # (sentinel physical map, -1 padding, zero replica counts).
    assert torch.equal(model_state.physical_to_logical_map[1], original_phy2log[1])
    assert torch.all(model_state.logical_to_physical_map[1] == -1)
    assert torch.all(model_state.logical_replica_count[1] == 0)
|