[EPLB] Optimize EPLB with numpy (#29499)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Ilya Markov
2026-01-07 21:21:35 +01:00
committed by GitHub
parent 0ada960a20
commit 6170d47d22
8 changed files with 732 additions and 266 deletions

View File

@@ -310,3 +310,143 @@ if __name__ == "__main__":
print(phy2log)
test_basic_rebalance()
def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
"""Create replicas indices mapping from phy2log"""
pr = torch.zeros_like(phy2log)
for layer in range(phy2log.shape[0]):
seen: dict[int, int] = {}
row = phy2log[layer].tolist()
for i, expert in enumerate(row):
r = seen.get(expert, 0)
pr[layer, i] = r
seen[expert] = r + 1
return pr
def _validate_intragpu_rearrangement(
old_global_expert_indices: torch.Tensor,
new_phy2log: torch.Tensor,
new_phy_replicas_idx: torch.Tensor,
post_phy2log: torch.Tensor,
post_phy_replicas_idx: torch.Tensor,
num_ranks: int,
slots_per_gpu: int,
):
# Per-GPU checks
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
old_seg = old_global_expert_indices[0, start:end]
new_seg = new_phy2log[0, start:end]
new_rnk = new_phy_replicas_idx[0, start:end]
post_seg = post_phy2log[0, start:end]
post_rnk = post_phy_replicas_idx[0, start:end]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor):
pairs = list(zip(seg.tolist(), rnk.tolist()))
pairs.sort()
return pairs
assert sorted_pairs(post_seg, post_rnk) == sorted_pairs(new_seg, new_rnk), (
f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}"
)
# For experts that remain on the same GPU, the old slot is preserved
# for at least one occurrence; rank at that slot must be valid for that expert
old_list = old_seg.tolist()
new_list = new_seg.tolist()
post_list = post_seg.tolist()
remained = set(old_list) & set(new_list)
new_ranks_for_expert: dict[int, list[int]] = {}
for v, r in zip(new_list, new_rnk.tolist()):
new_ranks_for_expert.setdefault(v, []).append(r)
for expert in remained:
old_pos = old_list.index(expert)
assert post_list[old_pos] == expert, (
f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}"
)
# Rank at preserved slot must be one of the ranks
# the expert has in new mapping
assert post_rnk.tolist()[old_pos] in new_ranks_for_expert[expert], (
f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} "
"must come from new mapping"
)
@pytest.mark.parametrize(
"num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
[
pytest.param(
# Setup: 2 GPUs, 4 slots each, 1 layer
# Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
# New mapping shuffles within GPU0 and brings 4,5 into GPU0.
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2,
4,
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]),
torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]),
id="simple",
),
pytest.param(
# Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
# Old mapping:
# GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated)
# GPU1 -> [4, 5, 6, 1, 2]
# New mapping reorders within GPUs and moves some experts across GPUs,
# while still including duplicates:
# GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming)
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2,
5,
torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
id="duplicates",
),
pytest.param(
# Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
# Old mapping:
# GPU0 -> [0, 1, 2, 3]
# GPU1 -> [0, 1, 2, 3]
# GPU2 -> [0, 1, 2, 3]
# New mapping decides to use one expert on 2 GPUs and shuffles
# experts on the third GPU,
# GPU0 new -> [0, 0, 0, 0]
# GPU1 new -> [0, 0, 0, 0]
# GPU2 new -> [1, 2, 3, 0]
3,
4,
torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
id="skewed_expert",
),
],
)
def test_preserve_intragpu_slots(
num_ranks: int,
slots_per_gpu: int,
old_phy2log: torch.Tensor,
new_phy2log: torch.Tensor,
):
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
)
# Shapes preserved
assert post_phy2log.shape == new_phy2log.shape
assert post_phy_replicas_idx.shape == phy_replicas_idx.shape
_validate_intragpu_rearrangement(
old_phy2log,
new_phy2log,
phy_replicas_idx,
post_phy2log,
post_phy_replicas_idx,
num_ranks,
slots_per_gpu,
)
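For intuition, the "simple" case above can be traced by hand: experts that stay on their GPU keep their old slots, and incoming experts fill the remaining slots in order. An illustrative, hand-derived sketch (assuming DefaultEplbPolicy is imported as in the test file above):

import torch

# Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]; the new mapping shuffles within
# GPUs and moves experts across GPUs.
old = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
new = torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]])
replicas = torch.zeros_like(new)  # no logical expert is duplicated in this case

post, post_replicas = DefaultEplbPolicy.preserve_intragpu_slots(
    new, replicas, num_ranks=2, old_global_expert_indices=old
)
# Experts 0 and 1 keep their old GPU0 slots, 6 and 7 keep their old GPU1 slots;
# incoming experts (5, 4 on GPU0 and 2, 3 on GPU1) fill the freed slots in order:
# post == tensor([[0, 1, 5, 4, 2, 3, 6, 7]])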

View File

@@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker(
device,
old_indices,
)
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device)
for layer_idx in range(num_layers):
-is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
-old_global_expert_indices=old_indices,
-new_global_expert_indices=new_indices,
old_global_expert_indices=old_indices_cpu,
new_global_expert_indices=new_indices_cpu,
expert_weights=expert_weights,
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
@@ -302,16 +304,15 @@ def _test_async_transfer_layer_without_mtp_worker(
cuda_stream=cuda_stream,
)
)
cuda_stream.synchronize()
move_from_buffer(
expert_weights=expert_weights[layer_idx],
-expert_weights_buffer=expert_buffer,
expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
-experts_recv_loc=experts_recv_loc,
-new_indices=new_indices[layer_idx].tolist(),
-ep_group=ep_group,
recv_metadata=recv_metadata,
new_indices=new_indices_cpu[layer_idx],
ep_rank=ep_rank,
)
verify_expert_weights_after_shuffle(

View File

@@ -69,6 +69,10 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
log_balancedness_interval: int = 1
"""
Interval for logging the balancedness.
"""
use_async: bool = False
"""
Whether to use non-blocking EPLB.
@@ -77,6 +81,14 @@ class EPLBConfig:
policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB)."""
@model_validator(mode="after")
def _validate_eplb_config(self) -> Self:
if self.use_async and self.policy != "default":
raise ValueError("Async EPLB is only supported with the default policy.")
if self.log_balancedness and self.log_balancedness_interval <= 0:
raise ValueError("log_balancedness_interval must be greater than 0.")
return self
@config
@dataclass
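The new validator runs when the config object is constructed, so incompatible settings fail immediately. An illustrative sketch (assuming EPLBConfig can be constructed directly with keyword arguments; pydantic surfaces the validator's ValueError as a validation error, which is still a ValueError):

try:
    EPLBConfig(log_balancedness=True, log_balancedness_interval=0)
except ValueError as err:
    print(err)  # includes "log_balancedness_interval must be greater than 0."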

View File

@@ -89,7 +89,7 @@ async def transfer_run_periodically(
(
model_state.is_unchanged,
model_state.is_received_locally,
-model_state.experts_recv_loc,
model_state.recv_metadata,
) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map,
new_global_expert_indices=model_state.new_physical_to_logical_map,

View File

@@ -27,10 +27,10 @@ physical experts.
""" """
import threading import threading
import time
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np
import torch import torch
from torch.distributed import ProcessGroup, all_reduce from torch.distributed import ProcessGroup, all_reduce
@@ -46,7 +46,11 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts
from .async_worker import start_async_worker
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
-from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
from .rebalance_execute import (
RecvMetadata,
move_from_buffer,
rearrange_expert_weights_inplace,
)
logger = init_logger(__name__)
@@ -164,20 +168,19 @@ class EplbModelState:
""" """
Whether the async EPLB needs to poll peers for buffer readiness. Whether the async EPLB needs to poll peers for buffer readiness.
""" """
is_unchanged: list[bool] is_unchanged: np.ndarray
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
The size is same as the num of physical experts in the current layer. The size is same as the num of physical experts in the current layer.
""" """
is_received_locally: list[bool] is_received_locally: np.ndarray
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
The size is same as the num of physical experts in the current layer. The size is same as the num of physical experts in the current layer.
""" """
experts_recv_loc: dict[int, int] recv_metadata: RecvMetadata
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
The size is same as the num of physical experts in the current layer.
""" """
is_async_enabled: bool is_async_enabled: bool
""" """
@@ -507,9 +510,14 @@ class EplbState:
layer_to_transfer=0,
rebalanced=False,
pending_global_ready_check=False,
-is_unchanged=[],
-is_received_locally=[],
-experts_recv_loc={},
is_unchanged=np.array([]),
is_received_locally=np.array([]),
recv_metadata=RecvMetadata(
recv_primary_mask=np.array([]),
recv_count=0,
recv_expert_ids=np.array([]),
recv_dst_rows=np.array([]),
),
is_async_enabled=self.is_async,
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
@@ -553,7 +561,12 @@ class EplbState:
for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_pass.zero_()
-if log_stats:
if (
log_stats
and self.expert_rearrangement_step
% self.parallel_config.eplb_config.log_balancedness_interval
== 0
):
# Sync the expert load pass for each model (main and drafter).
# expert_load_pass: (num_moe_layers, num_physical_experts)
expert_load_pass_list = self._sync_load_pass()
@@ -586,12 +599,15 @@ class EplbState:
if ep_group.rank() == 0:
logger.info(
"EPLB step: %d for model %s: avg_tokens=%.2f, "
-"max_tokens=%d, balancedness=%.4f",
"max_tokens=%d, balancedness=%.4f, "
"steps until the next rearrangement: %d",
self.expert_rearrangement_step,
eplb_model_state.model_name,
avg_tokens,
max_tokens,
balancedness,
self.expert_rearrangement_step_interval
- self.expert_rearrangement_step,
)
# Update the expert load sliding window
@@ -684,11 +700,14 @@ class EplbState:
ep_group = get_ep_group().device_group
ep_rank = ep_group.rank()
-time_start = None
start_event = None
end_event = None
is_main_rank = ep_rank == 0
if is_main_rank:
-torch.cuda.synchronize()
-time_start = time.perf_counter()
if not self.is_async or is_profile:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
logger.info(
"Rearranging experts %s %s...",
"(async mode)" if self.is_async else "sync mode",
@@ -800,6 +819,7 @@ class EplbState:
num_groups,
num_nodes,
num_gpus,
eplb_model_state.physical_to_logical_map,
)
if not eplb_model_state.is_async_enabled or is_profile:
@@ -848,17 +868,17 @@ class EplbState:
new_logical_replica_count
)
if is_main_rank:
-assert time_start is not None
-torch.cuda.synchronize()
-time_end = time.perf_counter()
assert start_event is not None
assert end_event is not None
end_event.record()
end_event.synchronize()
gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
logger.info(
-"Rearranged experts%sin %.2f seconds.",
"Rearranged experts %s in %.2f s.",
" (profile) " if is_profile else " ",
-time_end - time_start,
gpu_elapsed,
)
else:
-device = eplb_model_state.physical_to_logical_map.device
-new_physical = new_physical_to_logical_map.to(device)
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
@@ -869,7 +889,10 @@ class EplbState:
eplb_model_state.logical_replica_count.device
)
-eplb_model_state.new_physical_to_logical_map = new_physical
# Move map to cpu in advance
eplb_model_state.new_physical_to_logical_map = (
new_physical_to_logical_map.cpu()
)
eplb_model_state.new_logical_to_physical_map = padded_logical
eplb_model_state.new_logical_replica_count = new_replica
@@ -968,25 +991,30 @@ class EplbState:
stream = torch.cuda.current_stream(device=device_index)
stream.wait_event(model_state.buffer_ready_event)
model_state.buffer_ready_event = None
expert_weights = model_state.model.expert_weights[
model_state.layer_to_transfer
]
expert_weights_buffer = model_state.expert_buffer
new_indices = (
model_state.new_physical_to_logical_map[model_state.layer_to_transfer]
.cpu()
.numpy()
)
move_from_buffer(
-expert_weights=model_state.model.expert_weights[
-model_state.layer_to_transfer
-],
-expert_weights_buffer=model_state.expert_buffer,
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
is_unchanged=model_state.is_unchanged,
is_received_locally=model_state.is_received_locally,
-experts_recv_loc=model_state.experts_recv_loc,
-new_indices=model_state.new_physical_to_logical_map[
-model_state.layer_to_transfer
-].tolist(),
-ep_group=ep_group,
recv_metadata=model_state.recv_metadata,
new_indices=new_indices,
ep_rank=ep_group.rank(),
)
transferred_layer = model_state.layer_to_transfer
self._update_layer_mapping_from_new(model_state, transferred_layer)
# After the main thread consumes, advance layer_to_transfer
model_state.layer_to_transfer += 1
model_state.ep_buffer_ready = 0
-logger.info(
logger.debug(
"model %s successfully move_to_workspace layer %d",
model_state.model_name,
transferred_layer,
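The timing change earlier in this file swaps time.perf_counter() plus a device-wide synchronize for CUDA events recorded around the rearrangement. A standalone sketch of that pattern (not vLLM code; elapsed_time() reports milliseconds, hence the division by 1000, and a CUDA device is required):

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
# ... enqueue the GPU work to be timed, e.g. the expert-weight copies ...
end_event.record()
end_event.synchronize()  # wait for the end event only, not the whole device

gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
print(f"Rearranged experts in {gpu_elapsed:.2f} s")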

View File

@@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC):
num_groups: int,
num_nodes: int,
num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
@@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC):
num_groups: number of expert groups
num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_logical_experts], the old global
expert indices. Used to avoid unnecessary weight copying
for experts moving within one rank.
Returns:
physical_to_logical_map: [layers, num_replicas], the expert
index of each replica
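A sketch of how the extended entry point can be called (illustrative only; the weight/num_replicas argument order is assumed from the signatures shown in this diff, and DefaultEplbPolicy is imported as elsewhere in the EPLB package):

import torch

# 8 logical experts replicated onto 16 physical slots across 4 ranks, reusing the
# previous mapping so intra-GPU slots stay stable where possible.
num_layers, num_logical, num_replicas = 2, 8, 16
load = torch.randint(1, 100, (num_layers, num_logical))                    # per-expert token counts
old_map = torch.arange(num_replicas).repeat(num_layers, 1) % num_logical   # previous phy2log

phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
    load,
    num_replicas,
    num_groups=4,
    num_nodes=1,
    num_ranks=4,
    old_global_expert_indices=old_map,
)
# phy2log: [layers, num_replicas]; log2phy: [layers, num_logical, max_replicas];
# logcnt:  [layers, num_logical] number of physical replicas per logical expert.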

View File

@@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
-rank: [X, num_phy], the replica rank
replica_idx: [X, num_phy], the index of the replica for each logical expert
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
@@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
-rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
-rank[:, i] = logcnt[arangen, redundant_indices]
replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
-return phy2log, rank, logcnt
return phy2log, replica_idx, logcnt
@classmethod
def rebalance_experts_hierarchical(
@@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
-log2phy: [layers, num_logical_experts, X],
pphy_replicas_idx: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
@@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes
)
-phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
@@ -203,9 +203,109 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
-pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
-return pphy2log, pphyrank, logcnt
return pphy2log, pphy_replicas_idx, logcnt
@classmethod
def preserve_intragpu_slots(
cls,
phy2log: torch.Tensor,
phy_replicas_idx: torch.Tensor,
num_ranks: int,
old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
fill any remaining available slots. This is applied only when the number of GPUs
is unchanged and the slots per GPU remain the same between
the old and new mappings.
"""
device = phy2log.device
num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx
# Move to CPU and convert to NumPy for processing
new_phy2log_np = phy2log.cpu().numpy()
replicas_idx_np = phy_replicas_idx.cpu().numpy()
old_phy2log_np = old_global_expert_indices.cpu().numpy()
slots_per_gpu = num_phy_experts // num_ranks
num_layers = new_phy2log_np.shape[0]
post_phy2log_np = new_phy2log_np.copy()
post_phy_replicas_idx_np = replicas_idx_np.copy()
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
# Experts across all layers for this GPU
old_local = old_phy2log_np[:, start:end] # [layers, slots]
new_local = new_phy2log_np[:, start:end] # [layers, slots]
new_ridx = replicas_idx_np[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
# First pass: preserve same-logical experts in their previous slots
for slot_idx in range(slots_per_gpu):
# matches: [layers, slots], True where new local experts have
# the same logical value as the old from 'slot_idx' and not checked yet
matches = (new_local == old_local[:, slot_idx][:, None]) & (
~used_new_indices
)
has_any = matches.any(axis=1)
if np.any(has_any):
first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices]
post_phy2log_np[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
new_ridx[layer_indices, matched_new_positions]
)
used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True
# Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [layers, slots]
fill_mask = ~preserved_positions # [layers, slots]
if remaining_mask.any() and fill_mask.any():
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
# Sentinel value for unavailable positions.
large = slots_per_gpu + 1
# Priorities: keep original index for available spots, set sentinel
# for unavailable; lower is earlier.
remaining_priority = np.where(remaining_mask, idx_base, large)
fill_priority = np.where(fill_mask, idx_base, large)
# Sort to get ordered indices of available src/dst positions per layer.
remaining_indices = np.argsort(remaining_priority, axis=1)
fill_indices = np.argsort(fill_priority, axis=1)
# Fill count per layer (cannot exceed either side).
remaining_counts = remaining_mask.sum(axis=1)
fill_counts = fill_mask.sum(axis=1)
take_counts = np.minimum(remaining_counts, fill_counts)
# Assign remaining new experts to remaining slots per layer.
for layer_idx in range(num_layers):
k = int(take_counts[layer_idx])
if k <= 0:
continue
src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k]
post_phy2log_np[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos
]
post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos
]
# Convert back to torch and move to original device
post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
return post_phy2log, post_phy_replicas_idx
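The second pass above pairs leftover new experts with unfilled slots via an argsort over a priority array, so both sides are consumed in ascending slot order without Python-level scanning. A standalone numpy sketch of that pairing trick with toy masks (values invented for illustration):

import numpy as np

slots_per_gpu = 5
remaining_mask = np.array([False, True, False, True, True])  # new experts not yet placed
fill_mask = np.array([True, False, True, False, True])       # slots not preserved in pass 1

idx_base = np.arange(slots_per_gpu)
large = slots_per_gpu + 1  # sentinel pushes unavailable positions to the end of the sort

src_order = np.argsort(np.where(remaining_mask, idx_base, large))
dst_order = np.argsort(np.where(fill_mask, idx_base, large))

k = min(remaining_mask.sum(), fill_mask.sum())
print(src_order[:k], dst_order[:k])  # [1 3 4] [0 2 4]: sources 1, 3, 4 fill slots 0, 2, 4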
@classmethod
def rebalance_experts(
@@ -215,6 +315,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_groups: int,
num_nodes: int,
num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
@@ -228,7 +329,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_logical_experts], the old global
expert indices. Used to avoid unnecessary weight copying
for experts moving within one rank.
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
@@ -241,14 +344,23 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
-phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks
)
else:
# use global load-balance policy
-phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks
)
# Optional postprocessing to preserve slots for experts moving
# within the same GPU
# Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move
# within the same GPU.
if old_global_expert_indices is not None:
phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
@@ -259,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
)
log2phy.view(num_layers, -1).scatter_(
-1,
-phy2log * maxlogcnt + phyrank,
phy2log * maxlogcnt + phy_replicas_idx,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
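The final scatter_ encodes each (logical expert, replica index) pair into a flat position, so that log2phy[layer, l, r] ends up holding the physical slot of the r-th replica of logical expert l. A small worked example with illustrative values:

import torch

# 4 physical slots, 3 logical experts; slot 2 holds the second replica of expert 0,
# and at most maxlogcnt = 2 replicas per logical expert.
phy2log = torch.tensor([[0, 1, 0, 2]])
replica_idx = torch.tensor([[0, 0, 1, 0]])
num_layers, num_replicas = phy2log.shape
maxlogcnt = 2

log2phy = torch.full((num_layers, 3, maxlogcnt), -1, dtype=torch.int64)
log2phy.view(num_layers, -1).scatter_(
    -1,
    phy2log * maxlogcnt + replica_idx,
    torch.arange(num_replicas, dtype=torch.int64).expand(num_layers, -1),
)
print(log2phy)
# tensor([[[ 0,  2],    expert 0 -> physical slots 0 and 2
#          [ 1, -1],    expert 1 -> slot 1
#          [ 3, -1]]])  expert 2 -> slot 3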

View File

@@ -6,9 +6,10 @@ The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
-from collections.abc import Iterable, MutableSequence, Sequence
-from functools import partial
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
import numpy as np
import torch
from torch.distributed import (
P2POp,
@@ -18,214 +19,318 @@ from torch.distributed import (
get_global_rank,
)
from vllm.logger import init_logger

-def idx_local_to_global(
-local_idx: int,
-local_cnt: int,
-ep_rank: int,
-) -> int:
-"""
-Convert a local expert index to a global expert index.
-"""
-return ep_rank * local_cnt + local_idx
-def idx_global_to_local(
-global_idx: int,
-local_cnt: int,
-ep_rank: int,
-) -> int:
-"""
-Convert a global expert index to a local expert index.
-"""
-return global_idx - ep_rank * local_cnt
-def global_idx_to_rank(
-global_idx: int,
-local_cnt: int,
-) -> int:
-"""
-Convert a global expert index to a rank index.
-"""
-return global_idx // local_cnt

logger = init_logger(__name__)

@dataclass
class RecvMetadata:
"""Metadata describing remote receives during EPLB rebalancing."""
recv_primary_mask: np.ndarray
"""Mask of (num_local_experts,) indicating primary experts received."""
recv_count: int
"""Number of received experts for the layer."""
recv_expert_ids: np.ndarray
"""Expert ids (num_local_experts,) of remote primary experts."""
recv_dst_rows: np.ndarray
"""Target expert indices (num_local_experts,) in local tensors to send."""

# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
-def get_ep_ranks_with_expert(
-idx: int,
-num_local_experts: int,
-old_indices: Sequence[int],
-new_indices: Sequence[int],
-) -> tuple[MutableSequence[int], MutableSequence[int]]:
-"""
-Get the ranks of the experts that need to be exchanged.
-Args:
-idx: The index of the expert.
-num_local_experts: The number of local experts.
-old_indices: The old indices of the experts.
-new_indices: The new indices of the experts.
-Returns:
-A tuple of two lists:
-- The ranks of the experts that need to be sent.
-- The ranks of the experts that need to be received.
-"""
-global2rank = partial(
-global_idx_to_rank,
-local_cnt=num_local_experts,
-)
-ranks_to_send: list[int] = []
-ranks_to_recv: list[int] = []
-for i, e in enumerate(old_indices):
-if e == idx:
-rank = global2rank(i)
-if not ranks_to_send or ranks_to_send[-1] != rank:
-ranks_to_send.append(rank)
-for i, e in enumerate(new_indices):
-if e == idx:
-rank = global2rank(i)
-if not ranks_to_recv or ranks_to_recv[-1] != rank:
-ranks_to_recv.append(rank)
-# Remove those ranks that can get this expert locally.
-ranks_to_send_set = set(ranks_to_send)
-ranks_to_recv_actual = [
-rank for rank in ranks_to_recv if rank not in ranks_to_send_set
-]
-return ranks_to_send, ranks_to_recv_actual

def get_ep_ranks_with_experts_batch(
expert_ids: np.ndarray,
num_local_experts: int,
old_indices: np.ndarray,
new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
"""
Get the ranks of the experts that need to be exchanged.
Args:
expert_ids: 1D array of expert indices to query.
num_local_experts: The number of local experts.
old_indices: The old indices of the experts.
new_indices: The new indices of the experts.
Returns:
A tuple of two dictionaries mapping expert_id to:
- ranks_to_send: The ranks that have this expert and need to send.
- ranks_to_recv: The ranks that need to receive this expert.
"""
ranks_to_send_map: dict[int, list[int]] = {}
ranks_to_recv_map: dict[int, list[int]] = {}
# Fast path: if no experts, return empty dicts
if expert_ids.size == 0:
return ranks_to_send_map, ranks_to_recv_map
unique_experts = np.unique(expert_ids)
num_positions = len(old_indices)
position_indices = np.arange(num_positions, dtype=np.int32)
# Vectorized approach: find all positions matching any query expert in one pass
# Use np.isin to get boolean masks for all relevant positions at once
old_relevant_mask = np.isin(old_indices, unique_experts)
new_relevant_mask = np.isin(new_indices, unique_experts)
# Process old_indices (send ranks)
if np.any(old_relevant_mask):
old_relevant_positions = position_indices[old_relevant_mask]
old_relevant_experts = old_indices[old_relevant_mask]
old_relevant_ranks = old_relevant_positions // num_local_experts
# Sort by expert first, then by position (to maintain first-appearance order)
sort_order = np.lexsort((old_relevant_positions, old_relevant_experts))
sorted_experts = old_relevant_experts[sort_order]
sorted_ranks = old_relevant_ranks[sort_order]
# Find boundaries where expert changes
expert_boundaries = np.concatenate(
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
)
# For each expert, extract unique ranks in order of first appearance
for i in range(len(expert_boundaries) - 1):
start, end = expert_boundaries[i], expert_boundaries[i + 1]
expert = int(sorted_experts[start])
expert_ranks = sorted_ranks[start:end]
# Get unique ranks preserving order
_, unique_idx = np.unique(expert_ranks, return_index=True)
unique_ranks = expert_ranks[np.sort(unique_idx)]
ranks_to_send_map[expert] = unique_ranks.tolist()
# Process new_indices (recv ranks)
if np.any(new_relevant_mask):
new_relevant_positions = position_indices[new_relevant_mask]
new_relevant_experts = new_indices[new_relevant_mask]
new_relevant_ranks = new_relevant_positions // num_local_experts
# Sort by expert first, then by position
sort_order = np.lexsort((new_relevant_positions, new_relevant_experts))
sorted_experts = new_relevant_experts[sort_order]
sorted_ranks = new_relevant_ranks[sort_order]
# Find boundaries where expert changes
expert_boundaries = np.concatenate(
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
)
# For each expert, extract unique ranks and exclude local copies
for i in range(len(expert_boundaries) - 1):
start, end = expert_boundaries[i], expert_boundaries[i + 1]
expert = int(sorted_experts[start])
expert_ranks = sorted_ranks[start:end]
# Get unique ranks preserving order
_, unique_idx = np.unique(expert_ranks, return_index=True)
unique_ranks = expert_ranks[np.sort(unique_idx)]
# Remove ranks that have local copies (in send map)
send_ranks_set = set(ranks_to_send_map.get(expert, []))
recv_ranks_actual = [
int(r) for r in unique_ranks if r not in send_ranks_set
]
ranks_to_recv_map[expert] = recv_ranks_actual
# Handle experts that only appear in old (send only) or new (recv only)
for expert in unique_experts:
expert = int(expert)
if expert not in ranks_to_send_map:
ranks_to_send_map[expert] = []
if expert not in ranks_to_recv_map:
ranks_to_recv_map[expert] = []
return ranks_to_send_map, ranks_to_recv_map
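A toy illustration of the batched lookup's contract (values invented for the example): with 2 ranks and 2 local experts per rank, positions 0-1 live on rank 0 and positions 2-3 on rank 1.

import numpy as np

old = np.array([0, 1, 1, 2])   # rank 0 currently holds experts [0, 1]; rank 1 holds [1, 2]
new = np.array([0, 2, 1, 0])   # rank 0 should end up with [0, 2]; rank 1 with [1, 0]

send_map, recv_map = get_ep_ranks_with_experts_batch(
    np.array([0, 2]), num_local_experts=2, old_indices=old, new_indices=new
)
print(send_map)  # {0: [0], 2: [1]}  ranks that already hold each queried expert
print(recv_map)  # {0: [1], 2: [0]}  ranks that still need it (local holders excluded)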
def move_to_buffer(
num_local_experts: int,
-old_indices: Sequence[int],
-new_indices: Sequence[int],
old_indices: np.ndarray,
new_indices: np.ndarray,
expert_weights: Iterable[torch.Tensor],
-expert_weights_buffer: Sequence[torch.Tensor],
expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
-) -> tuple[list[bool], list[bool], dict[int, int]]:
) -> MoveToBufferResult:
"""
-Perform expert weights rearrangement of one layer.
Rearranges expert weights during EPLB rebalancing.
Args:
num_local_experts: Number of local experts.
old_indices: (num_experts_total,) ndarray of current (old)
global-to-local expert assignments.
new_indices: (num_experts_total,) ndarray of desired (new)
global-to-local assignments after rebalance.
expert_weights: Original expert weights for the layer.
expert_weights_buffers: Intermediate buffers (one per tensor).
cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_group: Distributed process group for expert parallel comms.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
assert old_indices.shape == new_indices.shape
ep_rank = ep_group.rank() ep_rank = ep_group.rank()
-local2global = partial(
-idx_local_to_global,
-local_cnt=num_local_experts,
-ep_rank=ep_rank,
-)
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
base = ep_rank * num_local_experts
local_rows = np.arange(num_local_experts, dtype=np.int32)
local_global = base + local_rows
old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = new_indices[local_global]
# Unchanged mask
is_unchanged = old_local_expert_ids == new_local_expert_ids
# Local receive eligibility
new_valid = new_local_expert_ids != -1
can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
)
is_received_locally = np.logical_or(
is_unchanged, np.logical_and(new_valid, can_recv_local)
)
-# 0. Do nothing for experts that did not change.
-is_unchanged = [
-old_indices[local2global(i)] == new_indices[local2global(i)]
-for i in range(num_local_experts)
-]
# Send map: first src row per unique expert present locally in old mapping
send_count = 0
valid_old = old_local_expert_ids != -1
if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
)
filtered_rows = local_rows[valid_old]
src_rows = filtered_rows[first_idx]
send_count = int(uniq_experts.shape[0])
send_expert_ids[:send_count] = uniq_experts
send_src_rows[:send_count] = src_rows
-# 1. Perform weight copy inside the local rank.
-is_received_locally = is_unchanged[:]
-for src in range(num_local_experts):
-src_global = local2global(src)
-for dst in range(num_local_experts):
-dst_global = local2global(dst)
-if is_received_locally[dst]:
-continue
-if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
-continue
-if old_indices[src_global] == new_indices[dst_global]:
-is_received_locally[dst] = True
-for weight, buffer in zip(expert_weights, expert_weights_buffer):
-with torch.cuda.stream(cuda_stream):
-buffer[dst].copy_(weight[src], non_blocking=True)
# Recv map: primary dst per unique expert needed remotely
recv_count = 0
need_recv_mask = np.logical_and(~is_received_locally, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
dst_rows = desired_dsts[uniq_indices]
recv_count = int(uniq_recv_experts.shape[0])
recv_expert_ids[:recv_count] = uniq_recv_experts
recv_dst_rows[:recv_count] = dst_rows
recv_primary_mask[dst_rows] = True
eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally)
# 1. Local moves into tmp buffers
if bool(eligible_local_buffer_mask.any()) and send_count > 0:
dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
expert_to_src_map = dict(
zip(send_expert_ids[:send_count], send_src_rows[:send_count])
)
for dst in dest_indices:
expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1)
if src_local != -1:
for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
-# 2. Initiate sending of weights.
-experts_send_loc: dict[int, int] = {}
-for src in range(num_local_experts):
-expert = old_indices[local2global(src)]
-if expert == -1:
-continue
-if expert in experts_send_loc:
-continue
-experts_send_loc[expert] = src
# Pre-compute global ranks mapping
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
-# We need to sort here to match send/recv
-for expert, src in sorted(experts_send_loc.items()):
-ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
-expert,
-num_local_experts,
-old_indices,
-new_indices,
-)
# 2. Post sends
if send_count > 0:
experts = send_expert_ids[:send_count]
srcs = send_src_rows[:send_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
srcs = srcs[order]
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
new_indices,
)
-# Calculate the ranks to send by this rank
-num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
-sender_pos = ranks_to_send.index(ep_rank)
-recv_begin = sender_pos * num_dst_per_sender
-recv_end = recv_begin + num_dst_per_sender
-recv_ranks = ranks_to_recv[recv_begin:recv_end]
-# Tackle remainders
-remainder_start = len(ranks_to_send) * num_dst_per_sender
-recver_pos = remainder_start + sender_pos
-if recver_pos < len(ranks_to_recv):
-recv_ranks.append(ranks_to_recv[recver_pos])
-for dst in recv_ranks:
-dst_global = get_global_rank(ep_group, dst)
-p2p_ops += [
-P2POp(
-torch.distributed.isend,
-weight[src],
-dst_global,
-)
-for weight in expert_weights
-]
for expert, src in zip(experts.tolist(), srcs.tolist()):
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# 3. Post recvs
if recv_count > 0:
experts = recv_expert_ids[:recv_count]
dsts = recv_dst_rows[:recv_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
dsts = dsts[order]
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
new_indices,
)
for expert, dst in zip(experts.tolist(), dsts.tolist()):
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
-# 3. Initiate receiving of weights.
-experts_recv_loc: dict[int, int] = {}
-for dst in range(num_local_experts):
-if is_received_locally[dst]:
-continue
-expert = new_indices[local2global(dst)]
-if expert == -1:
-continue
-if expert in experts_recv_loc:
-continue
-experts_recv_loc[expert] = dst
-# We need to sort here to match send/recv
-for expert, dst in sorted(experts_recv_loc.items()):
-ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
-expert,
-num_local_experts,
-old_indices,
-new_indices,
-)
-# Calculate the rank to recv by this rank
-num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
-recver_pos = ranks_to_recv.index(ep_rank)
-remainder_start = len(ranks_to_send) * num_dst_per_sender
-if recver_pos < remainder_start:
-src = ranks_to_send[recver_pos // num_dst_per_sender]
-else:
-src = ranks_to_send[recver_pos - remainder_start]
-src_global = get_global_rank(ep_group, src)
-p2p_ops += [
-P2POp(
-torch.distributed.irecv,
-weight[dst],
-src_global,
-)
-for weight in expert_weights_buffer
-]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
@@ -237,38 +342,95 @@ def move_to_buffer(
for req in reqs:
req.wait()
# wait for the communication to finish
-return is_unchanged, is_received_locally, experts_recv_loc
return (
is_unchanged,
is_received_locally,
RecvMetadata(
recv_primary_mask=recv_primary_mask,
recv_count=recv_count,
recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows,
),
)
def move_from_buffer(
expert_weights: Iterable[torch.Tensor],
-expert_weights_buffer: list[torch.Tensor],
-is_unchanged: list[bool],
-is_received_locally: list[bool],
-experts_recv_loc: dict[int, int],
-new_indices: Sequence[int],
-ep_group: ProcessGroup,
expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
recv_metadata: RecvMetadata,
new_indices: np.ndarray,
ep_rank: int,
) -> None:
-ep_rank = ep_group.rank()
-num_local_experts = len(is_unchanged)
-local2global = partial(
-idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
-)
"""
Copies expert weights from communication buffers back to the target weight tensors
after EPLB rebalancing.
Args:
expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers containing the experts weights
after the transfer is completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance.
ep_rank: Rank of the process in the expert parallel group.
"""
recv_primary_mask = recv_metadata.recv_primary_mask
recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids
recv_dst_rows = recv_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers:
# copy if locally received OR remote primary recv
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
if bool(dest_mask_np.any()):
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
for dst in dest_indices:
for w, b in zip(expert_weights, expert_weights_buffers):
w[dst].copy_(b[dst], non_blocking=True)
if recv_count == 0:
return
# Duplicate remote received rows to non-primary duplicate dsts
base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
duplicate_mask = np.logical_and(
np.logical_and(~is_unchanged, ~is_received_locally),
np.logical_and(~recv_primary_mask, local_experts != -1),
)
# All received experts are unique in the destination, so no need to copy duplicates
if not bool(duplicate_mask.any()):
return
-for dst in range(num_local_experts):
-if is_unchanged[dst]:
-continue
-if is_received_locally[dst]:
-for weight, buffer in zip(expert_weights, expert_weights_buffer):
-weight[dst].copy_(buffer[dst], non_blocking=True)
-else:
-expert = new_indices[local2global(dst)]
-if expert == -1:
-continue
-src = experts_recv_loc[expert]
-for weight, buffer in zip(expert_weights, expert_weights_buffer):
-weight[dst].copy_(buffer[src], non_blocking=True)
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
prim_experts = recv_expert_ids[:recv_count]
prim_dsts = recv_dst_rows[:recv_count]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
return
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in expert_weights:
w[dst].copy_(w[src], non_blocking=True)
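The duplicate handling above matches each non-primary duplicate row to the primary row that already received its expert, using a sorted searchsorted lookup. A standalone numpy sketch with toy values:

import numpy as np

prim_experts = np.array([7, 3, 5])   # experts received remotely, one primary row each
prim_dsts = np.array([0, 2, 4])      # local rows those experts landed in
dup_experts = np.array([3, 9])       # experts needed again in duplicate rows
dup_dst_rows = np.array([1, 5])      # the duplicate destination rows

order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]   # [3, 5, 7]
prim_dsts_sorted = prim_dsts[order]         # [2, 4, 0]

pos = np.searchsorted(prim_experts_sorted, dup_experts)  # [0, 3]
valid = np.logical_and(
    pos < prim_experts_sorted.shape[0],
    prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)] == dup_experts,
)
print(dup_dst_rows[valid], prim_dsts_sorted[pos[valid]])  # [1] [2]: row 1 copies from row 2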
async def transfer_layer(
@@ -281,7 +443,7 @@ async def transfer_layer(
layer: int = 0,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
-) -> tuple[list[bool], list[bool], dict[int, int]]:
) -> MoveToBufferResult:
"""
Rearranges the expert weights in place according to the new expert indices.
@@ -299,6 +461,13 @@ async def transfer_layer(
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
Returns:
is_unchanged (np.ndarray): (1, num_local_experts), True where expert
is left unchanged.
is_received_locally (np.ndarray): (1, num_local_experts), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
ep_size = ep_group.size() ep_size = ep_group.size()
if rank_mapping is not None: if rank_mapping is not None:
@@ -323,16 +492,19 @@ async def transfer_layer(
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts
-is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
-old_indices=old_global_expert_indices[layer].tolist(),
-new_indices=new_global_expert_indices[layer].tolist(),
old_indices=old_global_expert_indices_np[layer],
new_indices=new_global_expert_indices_np[layer],
expert_weights=expert_weights[layer],
-expert_weights_buffer=expert_weights_buffer,
expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream,
ep_group=ep_group,
)
-return is_unchanged, is_received_locally, experts_recv_loc
return is_unchanged, is_received_locally, recv_metadata

def rearrange_expert_weights_inplace(
@@ -388,19 +560,17 @@ def rearrange_expert_weights_inplace(
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
-# A buffer to hold the expert weights in one layer during the exchange.
first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
-expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
weights_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights
]
if is_profile:
-# Maximum send size is to send all local experts to all ranks,
-# So we use a dummy `all_gather` to reserve enough communication buffer
-for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
-# A `/dev/null`-like buffer to avoid real memory allocation
# Reserve communication buffers via a minimal dummy all_gather on first layer
for weight, buffer in zip(expert_weights[0], weights_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
-# NOTE(bowen): Needed this barrier to avoid OOM during actual
-# execution. I'm not very sure why this is needed
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
@@ -409,32 +579,32 @@ def rearrange_expert_weights_inplace(
)
return
-old_global_expert_indices_cpu = old_global_expert_indices.cpu()
-new_global_expert_indices_cpu = new_global_expert_indices.cpu()
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
-for layer in range(num_moe_layers):
-is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
for layer_idx in range(num_moe_layers):
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
-old_indices=old_global_expert_indices_cpu[layer].tolist(),
-new_indices=new_global_expert_indices_cpu[layer].tolist(),
-expert_weights=expert_weights[layer],
-expert_weights_buffer=expert_weights_buffer,
old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices=new_global_expert_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
cuda_stream=None,
ep_group=ep_group,
)
move_from_buffer(
-expert_weights=expert_weights[layer],
-expert_weights_buffer=expert_weights_buffer,
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
-experts_recv_loc=experts_recv_loc,
-new_indices=new_global_expert_indices[layer].tolist(),
-ep_group=ep_group,
recv_metadata=recv_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx],
ep_rank=ep_group.rank(),
)
@@ -526,4 +696,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices

-__all__ = ["transfer_layer", "move_from_buffer"]
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]