[EPLB] Add alternative communication for EPLB weight exchange (#33176)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
Author: Ilya Markov
Date: 2026-03-31 14:17:12 +02:00
Committed by: GitHub
Parent: 0c63739135
Commit: abdbb68386
12 changed files with 635 additions and 183 deletions
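
At a high level, this change routes the EPLB expert-weight exchange through a pluggable communicator object instead of hard-wiring the collective calls, and the tests are parametrized over the available backends. A minimal sketch of the pattern the updated tests exercise, using only names that appear in this diff (the process-group setup around it is assumed to exist already):

    from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator

    # Backend is one of the values parametrized in the tests below:
    # "torch_nccl", "torch_gloo", or "pynccl".
    communicator = create_eplb_communicator(
        group_coordinator=get_tp_group(),
        backend="torch_nccl",
        expert_weights=expert_weights[0],
    )

    # The async transfer path runs on a dedicated CUDA stream.
    communicator.set_stream(torch.cuda.Stream())

    # Both rearrangement entry points now take the communicator explicitly;
    # transfer_layer likewise accepts communicator= in the async test below.
    rearrange_expert_weights_inplace(
        old_indices,
        new_indices,
        expert_weights,
        ep_group,
        communicator,
        is_profile=False,
    )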


@@ -9,6 +9,7 @@ import torch
import torch.distributed
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.rebalance_execute import (
move_from_buffer,
rearrange_expert_weights_inplace,
@@ -130,9 +131,10 @@ def verify_expert_weights_after_shuffle(
hidden_sizes: list[int],
ep_rank: int,
num_local_experts: int,
):
) -> bool:
"""Verify the weights after shuffling are correct."""
num_layers = len(expert_weights)
ok = True
for layer in range(num_layers):
for weight_idx, hidden_size in enumerate(hidden_sizes):
@@ -155,29 +157,38 @@ def verify_expert_weights_after_shuffle(
dtype=actual_weights.dtype,
)
torch.testing.assert_close(
actual_weights,
expected_weights,
msg=f"Layer {layer}, weight {weight_idx},"
f"local expert {local_expert}: "
f"weights do not match. "
f"Expected logical expert {expected_logical_expert}",
)
if not torch.equal(actual_weights, expected_weights):
ok = False
actual_head = actual_weights[:8].detach().cpu().tolist()
expected_head = expected_weights[:8].detach().cpu().tolist()
print(
"verify_expert_weights_after_shuffle failed: "
f"rank={ep_rank}, "
f"layer={layer}, weight_idx={weight_idx}, "
f"local_expert={local_expert}, "
f"expected_logical_expert={expected_logical_expert}, "
f"actual_head={actual_head}, expected_head={expected_head}",
flush=True,
)
return ok
def verify_redundant_experts_have_same_weights(
expert_weights: list[list[torch.Tensor]],
indices: torch.Tensor,
hidden_sizes: list[int],
ep_rank: int,
world_size: int,
num_local_experts: int,
):
) -> bool:
"""
Verify that all replicas of the same logical expert have the same weights.
"""
num_layers = len(expert_weights)
total_physical_experts = world_size * num_local_experts
ok = True
for layer in range(num_layers):
# Collect weights for all physical experts for each weight matrix
all_weights: list[torch.Tensor] = []
@@ -227,14 +238,54 @@ def verify_redundant_experts_have_same_weights(
# Verify that current physical expert's weights match the
# previously saved logical expert weights
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
if not torch.equal(
all_weights[weight_idx][physical_pos],
logical_expert_weights[logical_expert_id][weight_idx],
msg=f"Layer {layer}, weight {weight_idx},"
f"logical expert {logical_expert_id}: "
f"Physical expert {physical_pos} has different weights"
f"than expected",
)
):
ok = False
actual_head = (
all_weights[weight_idx][physical_pos][:8]
.detach()
.cpu()
.tolist()
)
reference_head = (
logical_expert_weights[logical_expert_id][weight_idx][:8]
.detach()
.cpu()
.tolist()
)
print(
"verify_redundant_experts_have_same_weights failed: "
f"rank={ep_rank}, "
f"layer={layer}, weight_idx={weight_idx}, "
f"logical_expert={logical_expert_id}, "
f"physical_pos={physical_pos}, "
f"actual_head={actual_head}, "
f"reference_head={reference_head}",
flush=True,
)
return ok
def assert_verification_synced(local_ok: bool, msg: str) -> None:
ok_tensor = torch.tensor([1 if local_ok else 0], device="cuda", dtype=torch.int32)
torch.distributed.all_reduce(ok_tensor, op=torch.distributed.ReduceOp.MIN)
assert bool(ok_tensor.item()), msg
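
The verification helpers above now return a bool instead of asserting on the rank that detected a mismatch, and assert_verification_synced MIN-reduces that flag across all ranks, so a failure seen on any single rank raises the same AssertionError everywhere rather than one rank dying while its peers block in a later collective. A condensed sketch of the call pattern repeated in the tests below:

    # Per-rank check returns a flag and prints diagnostics instead of raising...
    local_ok = verify_expert_weights_after_shuffle(
        expert_weights, new_indices, hidden_sizes, ep_rank, num_local_experts
    )
    # ...and the MIN all-reduce makes every rank fail (or pass) together.
    assert_verification_synced(local_ok, "Verification failed on at least one rank.")
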
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
try:
return create_eplb_communicator(
group_coordinator=group_coordinator,
backend=backend,
expert_weights=expert_weights,
)
except Exception as exc:
raise RuntimeError(
f"Failed to create EPLB communicator for backend={backend}: {exc}"
) from exc
def _test_async_transfer_layer_without_mtp_worker(
@@ -243,6 +294,7 @@ def _test_async_transfer_layer_without_mtp_worker(
num_layers: int,
num_local_experts: int,
num_logical_experts: int,
eplb_communicator: str,
) -> None:
set_env_vars_and_device(env)
@@ -254,8 +306,8 @@ def _test_async_transfer_layer_without_mtp_worker(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
tp_group = get_tp_group()
ep_group = tp_group.device_group
ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.device_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
@@ -298,6 +350,13 @@ def _test_async_transfer_layer_without_mtp_worker(
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
)
communicator.set_stream(cuda_stream)
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
@@ -306,6 +365,7 @@ def _test_async_transfer_layer_without_mtp_worker(
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
communicator=communicator,
cuda_stream=cuda_stream,
)
)
@@ -320,24 +380,38 @@ def _test_async_transfer_layer_without_mtp_worker(
ep_rank=ep_rank,
)
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
local_ok = verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
local_ok = (
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
world_size,
num_local_experts,
)
and local_ok
)
assert_verification_synced(
local_ok,
"Async transfer verification failed on at least one rank. "
"See logs for details.",
)
def _test_rearrange_expert_weights_with_redundancy(
env, world_size, num_layers, num_local_experts, num_logical_experts
env,
world_size,
num_layers,
num_local_experts,
num_logical_experts,
eplb_communicator: str,
) -> None:
# Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel)
@@ -351,7 +425,8 @@ def _test_rearrange_expert_weights_with_redundancy(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
@@ -387,6 +462,12 @@ def _test_rearrange_expert_weights_with_redundancy(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
)
# Execute weight rearrangement
rearrange_expert_weights_inplace(
old_indices,
@@ -394,24 +475,33 @@ def _test_rearrange_expert_weights_with_redundancy(
expert_weights,
ep_group,
is_profile=False,
communicator=communicator,
)
# Verify the rearrangement result
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
# Verify the rearrangement result
local_ok = verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
local_ok = (
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
world_size,
num_local_experts,
)
and local_ok
)
assert_verification_synced(
local_ok,
"Rearrange verification failed on at least one rank. See logs for details.",
)
@pytest.mark.parametrize(
@@ -437,8 +527,13 @@ def _test_rearrange_expert_weights_with_redundancy(
(4, 8, 8, 16),
],
)
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
def test_rearrange_expert_weights_with_redundancy(
world_size, num_layers, num_local_experts, num_logical_experts
world_size,
num_layers,
num_local_experts,
num_logical_experts,
eplb_communicator,
):
"""Test the functionality of rearranging expert weights with redundancy."""
@@ -450,6 +545,7 @@ def test_rearrange_expert_weights_with_redundancy(
num_layers,
num_local_experts,
num_logical_experts,
eplb_communicator,
)
@@ -464,7 +560,8 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
@@ -494,24 +591,40 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
)
# Execute rearrangement (should be no change)
rearrange_expert_weights_inplace(
indices,
indices, # Same indices
expert_weights,
ep_group,
communicator,
is_profile=False,
)
# Verify that the weights have not changed
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg=f"""Layer {layer}, weight {weight_idx}
should remain unchanged""",
# Verify that the weights have not changed
local_ok = True
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
if not torch.equal(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
):
local_ok = False
print(
"test_rearrange_expert_weights_no_change failed: "
f"layer={layer}, weight_idx={weight_idx}",
flush=True,
)
assert_verification_synced(
local_ok,
"No-change EPLB verification failed on at least one rank.",
)
@pytest.mark.parametrize(
@@ -520,11 +633,13 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
(2, 2, 2, 3),
],
)
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
def test_async_transfer_layer_without_mtp(
world_size: int,
num_layers: int,
num_local_experts: int,
num_logical_experts: int,
eplb_communicator: str,
):
"""Exercise async EPLB transfer path without MTP/spec decode."""
@@ -537,6 +652,7 @@ def test_async_transfer_layer_without_mtp(
num_layers,
num_local_experts,
num_logical_experts,
eplb_communicator,
)
@@ -549,7 +665,10 @@ def test_rearrange_expert_weights_no_change(world_size):
if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
distributed_run(
_test_rearrange_expert_weights_no_change,
world_size,
)
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
@@ -563,7 +682,8 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_group_coordinator = get_tp_group()
ep_group = ep_group_coordinator.cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
@@ -600,23 +720,40 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
)
# Execute profile mode rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
communicator,
is_profile=True, # Profile mode
)
# In profile mode, the weights should remain unchanged
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged",
# In profile mode, the weights should remain unchanged
local_ok = True
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
if not torch.equal(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
):
local_ok = False
print(
"test_rearrange_expert_weights_profile_mode failed: "
f"layer={layer}, weight_idx={weight_idx}",
flush=True,
)
assert_verification_synced(
local_ok,
"Profile-mode EPLB verification failed on at least one rank.",
)
@pytest.mark.parametrize("world_size", [2, 4])
@@ -625,4 +762,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
if torch.accelerator.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
distributed_run(
_test_rearrange_expert_weights_profile_mode,
world_size,
)