def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
    """Derive per-slot replica indices from a physical-to-logical map.

    Args:
        phy2log: 2-D tensor indexed as ``[layer, physical_slot]`` holding the
            logical expert id stored in each physical slot.

    Returns:
        A tensor shaped like ``phy2log`` where, within each layer, the k-th
        occurrence of a logical expert (scanning slots left to right) is
        assigned replica index ``k`` (starting at 0).
    """
    replicas_idx = torch.zeros_like(phy2log)
    for layer, row in enumerate(phy2log.tolist()):
        seen: dict[int, int] = {}
        for slot, expert in enumerate(row):
            occurrence = seen.get(expert, 0)
            replicas_idx[layer, slot] = occurrence
            seen[expert] = occurrence + 1
    return replicas_idx


def _validate_intragpu_rearrangement(
    old_global_expert_indices: torch.Tensor,
    new_phy2log: torch.Tensor,
    new_phy_replicas_idx: torch.Tensor,
    post_phy2log: torch.Tensor,
    post_phy_replicas_idx: torch.Tensor,
    num_ranks: int,
    slots_per_gpu: int,
) -> None:
    """Assert that the post-processed mapping is a valid intra-GPU reordering.

    For every layer and every GPU segment of ``slots_per_gpu`` slots, checks:

    * the multiset of ``(expert, replica_idx)`` pairs after post-processing
      equals the one in the new mapping (nothing is lost or invented);
    * every expert present on the GPU both before and after keeps the slot of
      its first old occurrence, and the replica index stored at that slot is
      one the expert actually has on this GPU in the new mapping.

    All layers are validated, not only layer 0.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Convert once up front; slicing plain lists is cheaper than re-calling
    # ``.tolist()`` on tensor slices for every GPU.
    old_rows = old_global_expert_indices.tolist()
    new_rows = new_phy2log.tolist()
    new_rnk_rows = new_phy_replicas_idx.tolist()
    post_rows = post_phy2log.tolist()
    post_rnk_rows = post_phy_replicas_idx.tolist()

    for layer_idx in range(len(new_rows)):
        for gpu_idx in range(num_ranks):
            start = gpu_idx * slots_per_gpu
            segment = slice(start, start + slots_per_gpu)
            old_list = old_rows[layer_idx][segment]
            new_list = new_rows[layer_idx][segment]
            new_rnk = new_rnk_rows[layer_idx][segment]
            post_list = post_rows[layer_idx][segment]
            post_rnk = post_rnk_rows[layer_idx][segment]

            # Pairwise equality for (expert, rank) pairs: nothing is lost.
            assert sorted(zip(post_list, post_rnk)) == sorted(
                zip(new_list, new_rnk)
            ), (
                "Per-GPU pairs of (expert,rank) must match new mapping "
                f"for GPU {gpu_idx} (layer {layer_idx})"
            )

            new_ranks_for_expert: dict[int, set[int]] = {}
            for expert, rank in zip(new_list, new_rnk):
                new_ranks_for_expert.setdefault(expert, set()).add(rank)

            # Experts that remain on this GPU keep the slot of their first
            # old occurrence.
            for expert in set(old_list) & new_ranks_for_expert.keys():
                old_pos = old_list.index(expert)
                assert post_list[old_pos] == expert, (
                    f"Expert {expert} on GPU {gpu_idx} should stay at "
                    f"old slot {old_pos} (layer {layer_idx})"
                )
                # Rank at the preserved slot must be one of the ranks the
                # expert has in the new mapping.
                assert post_rnk[old_pos] in new_ranks_for_expert[expert], (
                    f"Rank for expert {expert} at preserved slot on GPU "
                    f"{gpu_idx} must come from new mapping (layer {layer_idx})"
                )


@pytest.mark.parametrize(
    "num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
    [
        pytest.param(
            # Setup: 2 GPUs, 4 slots each, 1 layer
            # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
            # New mapping shuffles within GPU0 and brings 4,5 into GPU0.
            # GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
            2,
            4,
            torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]),
            torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]),
            id="simple",
        ),
        pytest.param(
            # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
            # Old mapping:
            #   GPU0 -> [0, 1, 0, 2, 3]  (expert 0 duplicated)
            #   GPU1 -> [4, 5, 6, 1, 2]
            # New mapping reorders within GPUs and moves some experts across
            # GPUs, while still including duplicates:
            #   GPU0 new -> [0, 5, 4, 0, 1]  (expert 0 duplicated, 4/5 incoming)
            #   GPU1 new -> [6, 2, 3, 2, 1]  (expert 2 duplicated)
            2,
            5,
            torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
            torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
            id="duplicates",
        ),
        pytest.param(
            # Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
            # Old mapping:
            #   GPU0 -> [0, 1, 2, 3]
            #   GPU1 -> [0, 1, 2, 3]
            #   GPU2 -> [0, 1, 2, 3]
            # New mapping decides to use one expert on 2 GPUs and shuffles
            # experts on the third GPU,
            #   GPU0 new -> [0, 0, 0, 0]
            #   GPU1 new -> [0, 0, 0, 0]
            #   GPU2 new -> [1, 2, 3, 0]
            3,
            4,
            torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
            torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
            id="skewed_expert",
        ),
    ],
)
def test_preserve_intragpu_slots(
    num_ranks: int,
    slots_per_gpu: int,
    old_phy2log: torch.Tensor,
    new_phy2log: torch.Tensor,
):
    """Experts that stay on a GPU keep their old slots; incoming not lost."""
    phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)

    post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
        new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
    )

    # Shapes preserved
    assert post_phy2log.shape == new_phy2log.shape
    assert post_phy_replicas_idx.shape == phy_replicas_idx.shape

    _validate_intragpu_rearrangement(
        old_phy2log,
        new_phy2log,
        phy_replicas_idx,
        post_phy2log,
        post_phy_replicas_idx,
        num_ranks,
        slots_per_gpu,
    )
is_unchanged=is_unchanged, is_received_locally=is_received_locally, - experts_recv_loc=experts_recv_loc, - new_indices=new_indices[layer_idx].tolist(), - ep_group=ep_group, + recv_metadata=recv_metadata, + new_indices=new_indices_cpu[layer_idx], + ep_rank=ep_rank, ) verify_expert_weights_after_shuffle( diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9273ca66e..2e1ca74ed 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -69,6 +69,10 @@ class EPLBConfig: Log the balancedness each step of expert parallelism. This is turned off by default since it will cause communication overhead. """ + log_balancedness_interval: int = 1 + """ + Interval for logging the balancedness. + """ use_async: bool = False """ Whether to use non-blocking EPLB. @@ -77,6 +81,14 @@ class EPLBConfig: policy: EPLBPolicyOption = "default" """The policy type for expert parallel load balancing (EPLB).""" + @model_validator(mode="after") + def _validate_eplb_config(self) -> Self: + if self.use_async and self.policy != "default": + raise ValueError("Async EPLB is only supported with the default policy.") + if self.log_balancedness and self.log_balancedness_interval <= 0: + raise ValueError("log_balancedness_interval must be greater than 0.") + return self + @config @dataclass diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index e4b4fc92e..9d7366996 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -89,7 +89,7 @@ async def transfer_run_periodically( ( model_state.is_unchanged, model_state.is_received_locally, - model_state.experts_recv_loc, + model_state.recv_metadata, ) = await transfer_layer( old_global_expert_indices=model_state.physical_to_logical_map, new_global_expert_indices=model_state.new_physical_to_logical_map, diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 7826b1286..a482c6f55 100644 --- 
a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -27,10 +27,10 @@ physical experts. """ import threading -import time from collections.abc import Sequence from dataclasses import dataclass +import numpy as np import torch from torch.distributed import ProcessGroup, all_reduce @@ -46,7 +46,11 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts from .async_worker import start_async_worker from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy -from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace +from .rebalance_execute import ( + RecvMetadata, + move_from_buffer, + rearrange_expert_weights_inplace, +) logger = init_logger(__name__) @@ -164,20 +168,19 @@ class EplbModelState: """ Whether the async EPLB needs to poll peers for buffer readiness. """ - is_unchanged: list[bool] + is_unchanged: np.ndarray """ intermediate variable between `move_to_buffer` and `move_to_workspace`. The size is same as the num of physical experts in the current layer. """ - is_received_locally: list[bool] + is_received_locally: np.ndarray """ intermediate variable between `move_to_buffer` and `move_to_workspace`. The size is same as the num of physical experts in the current layer. """ - experts_recv_loc: dict[int, int] + recv_metadata: RecvMetadata """ intermediate variable between `move_to_buffer` and `move_to_workspace`. - The size is same as the num of physical experts in the current layer. 
""" is_async_enabled: bool """ @@ -507,9 +510,14 @@ class EplbState: layer_to_transfer=0, rebalanced=False, pending_global_ready_check=False, - is_unchanged=[], - is_received_locally=[], - experts_recv_loc={}, + is_unchanged=np.array([]), + is_received_locally=np.array([]), + recv_metadata=RecvMetadata( + recv_primary_mask=np.array([]), + recv_count=0, + recv_expert_ids=np.array([]), + recv_dst_rows=np.array([]), + ), is_async_enabled=self.is_async, cuda_device_index=self.cuda_device_index, new_physical_to_logical_map=new_physical_to_logical_map, @@ -553,7 +561,12 @@ class EplbState: for eplb_model_state in self.model_states.values(): eplb_model_state.expert_load_pass.zero_() - if log_stats: + if ( + log_stats + and self.expert_rearrangement_step + % self.parallel_config.eplb_config.log_balancedness_interval + == 0 + ): # Sync the expert load pass for each model (main and drafter). # expert_load_pass: (num_moe_layers, num_physical_experts) expert_load_pass_list = self._sync_load_pass() @@ -586,12 +599,15 @@ class EplbState: if ep_group.rank() == 0: logger.info( "EPLB step: %d for model %s: avg_tokens=%.2f, " - "max_tokens=%d, balancedness=%.4f", + "max_tokens=%d, balancedness=%.4f, " + "steps until the next rearrangement: %d", self.expert_rearrangement_step, eplb_model_state.model_name, avg_tokens, max_tokens, balancedness, + self.expert_rearrangement_step_interval + - self.expert_rearrangement_step, ) # Update the expert load sliding window @@ -684,11 +700,14 @@ class EplbState: ep_group = get_ep_group().device_group ep_rank = ep_group.rank() - time_start = None + start_event = None + end_event = None is_main_rank = ep_rank == 0 if is_main_rank: - torch.cuda.synchronize() - time_start = time.perf_counter() + if not self.is_async or is_profile: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() logger.info( "Rearranging experts %s %s...", "(async mode)" if self.is_async else "sync mode", @@ 
-800,6 +819,7 @@ class EplbState: num_groups, num_nodes, num_gpus, + eplb_model_state.physical_to_logical_map, ) if not eplb_model_state.is_async_enabled or is_profile: @@ -848,17 +868,17 @@ class EplbState: new_logical_replica_count ) if is_main_rank: - assert time_start is not None - torch.cuda.synchronize() - time_end = time.perf_counter() + assert start_event is not None + assert end_event is not None + end_event.record() + end_event.synchronize() + gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0 logger.info( - "Rearranged experts%sin %.2f seconds.", + "Rearranged experts %s in %.2f s.", " (profile) " if is_profile else " ", - time_end - time_start, + gpu_elapsed, ) else: - device = eplb_model_state.physical_to_logical_map.device - new_physical = new_physical_to_logical_map.to(device) max_slots = eplb_model_state.logical_to_physical_map.shape[-1] padded_logical = torch.nn.functional.pad( new_logical_to_physical_map, @@ -869,7 +889,10 @@ class EplbState: eplb_model_state.logical_replica_count.device ) - eplb_model_state.new_physical_to_logical_map = new_physical + # Move map to cpu in advance + eplb_model_state.new_physical_to_logical_map = ( + new_physical_to_logical_map.cpu() + ) eplb_model_state.new_logical_to_physical_map = padded_logical eplb_model_state.new_logical_replica_count = new_replica @@ -968,25 +991,30 @@ class EplbState: stream = torch.cuda.current_stream(device=device_index) stream.wait_event(model_state.buffer_ready_event) model_state.buffer_ready_event = None + expert_weights = model_state.model.expert_weights[ + model_state.layer_to_transfer + ] + expert_weights_buffer = model_state.expert_buffer + new_indices = ( + model_state.new_physical_to_logical_map[model_state.layer_to_transfer] + .cpu() + .numpy() + ) move_from_buffer( - expert_weights=model_state.model.expert_weights[ - model_state.layer_to_transfer - ], - expert_weights_buffer=model_state.expert_buffer, + expert_weights=expert_weights, + 
expert_weights_buffers=expert_weights_buffer, is_unchanged=model_state.is_unchanged, is_received_locally=model_state.is_received_locally, - experts_recv_loc=model_state.experts_recv_loc, - new_indices=model_state.new_physical_to_logical_map[ - model_state.layer_to_transfer - ].tolist(), - ep_group=ep_group, + recv_metadata=model_state.recv_metadata, + new_indices=new_indices, + ep_rank=ep_group.rank(), ) transferred_layer = model_state.layer_to_transfer self._update_layer_mapping_from_new(model_state, transferred_layer) # After the main thread consumes, advance layer_to_transfer model_state.layer_to_transfer += 1 model_state.ep_buffer_ready = 0 - logger.info( + logger.debug( "model %s successfully move_to_workspace layer %d", model_state.model_name, transferred_layer, diff --git a/vllm/distributed/eplb/policy/abstract.py b/vllm/distributed/eplb/policy/abstract.py index 40ed621c8..f4435f11b 100644 --- a/vllm/distributed/eplb/policy/abstract.py +++ b/vllm/distributed/eplb/policy/abstract.py @@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC): num_groups: int, num_nodes: int, num_ranks: int, + old_global_expert_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Entry point for expert-parallelism load balancer. @@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC): num_groups: number of expert groups num_nodes: number of server nodes num_ranks: number of ranks, must be a multiple of `num_nodes` - + old_global_expert_indices: [layers, num_logical_experts], the old global + expert indices. Used to avoid unnecessary weight copying + for experts moving within one rank. 
Returns: physical_to_logical_map: [layers, num_replicas], the expert index of each replica diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index 6127ec703..ebbbc6db1 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): Returns: phy2log: [X, num_phy], logical expert id of each physical expert - rank: [X, num_phy], the replica rank + replica_idx: [X, num_phy], the index of the replica for each logical expert logcnt: [X, num_log], number of replicas for each logical expert """ n, num_log = weight.shape @@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy): assert num_redundant >= 0 device = weight.device phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device) for i in range(num_log, num_phy): redundant_indices = (weight / logcnt).max(dim=-1).indices phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] + replica_idx[:, i] = logcnt[arangen, redundant_indices] logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt + return phy2log, replica_idx, logcnt @classmethod def rebalance_experts_hierarchical( @@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): Returns: phy2log: [layers, num_replicas], the expert index of each replica - log2phy: [layers, num_logical_experts, X], + pphy_replicas_idx: [layers, num_logical_experts, X], the replica indices for each expert logcnt: [layers, num_logical_experts], number of physical replicas for each logical expert @@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): tokens_per_mlog = weight.gather(-1, mlog2log).view( 
@classmethod
def preserve_intragpu_slots(
    cls,
    phy2log: torch.Tensor,
    phy_replicas_idx: torch.Tensor,
    num_ranks: int,
    old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorder the new mapping per GPU so that experts that remain on the same GPU
    keep their previous slot positions when possible. Incoming experts to that
    GPU fill the remaining slots in their original relative order.

    The reordering is applied only when the layout is comparable, i.e.
    ``num_ranks`` is positive, evenly divides the number of physical experts,
    and the old mapping has exactly the same shape as the new one. Otherwise
    the inputs are returned untouched.

    Args:
        phy2log: [layers, num_phy], new logical expert id per physical slot.
        phy_replicas_idx: [layers, num_phy], replica index of each slot in the
            new mapping; moved together with its ``phy2log`` entry.
        num_ranks: number of GPUs the physical slots are evenly split across.
        old_global_expert_indices: [layers, num_phy], the previous
            physical-to-logical mapping.

    Returns:
        ``(post_phy2log, post_phy_replicas_idx)`` on the device of ``phy2log``.
        Within each GPU segment the result is a permutation of the new mapping.
    """
    num_phy_experts = phy2log.shape[1]
    if (
        num_ranks <= 0
        or num_phy_experts % num_ranks != 0
        # A differently shaped old map means the slot layout changed, so
        # "previous slot" has no meaning; leave the new mapping as is.
        or old_global_expert_indices.shape != phy2log.shape
        or phy_replicas_idx.shape != phy2log.shape
    ):
        return phy2log, phy_replicas_idx

    device = phy2log.device

    # Move to CPU and convert to NumPy for processing
    new_phy2log_np = phy2log.cpu().numpy()
    replicas_idx_np = phy_replicas_idx.cpu().numpy()
    old_phy2log_np = old_global_expert_indices.cpu().numpy()

    slots_per_gpu = num_phy_experts // num_ranks
    num_layers = new_phy2log_np.shape[0]

    # Work on copies: ``.numpy()`` may share memory with the input tensors.
    post_phy2log_np = new_phy2log_np.copy()
    post_phy_replicas_idx_np = replicas_idx_np.copy()

    for gpu_idx in range(num_ranks):
        start = gpu_idx * slots_per_gpu
        end = start + slots_per_gpu
        # Per-GPU windows over all layers, each of shape [layers, slots].
        old_local = old_phy2log_np[:, start:end]
        new_local = new_phy2log_np[:, start:end]
        new_ridx = replicas_idx_np[:, start:end]
        # Writable views into the outputs for this GPU.
        post_local = post_phy2log_np[:, start:end]
        post_ridx = post_phy_replicas_idx_np[:, start:end]

        used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
        preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)

        # First pass: keep same-logical experts in their previous slots.
        for slot_idx in range(slots_per_gpu):
            # True where a not-yet-consumed new entry holds the expert that
            # previously lived in ``slot_idx``.
            matches = (new_local == old_local[:, slot_idx, None]) & (
                ~used_new_indices
            )
            layer_indices = np.flatnonzero(matches.any(axis=1))
            if layer_indices.size == 0:
                continue
            # argmax on a boolean row yields the first matching position.
            matched_new_positions = matches[layer_indices].argmax(axis=1)
            post_local[layer_indices, slot_idx] = new_local[
                layer_indices, matched_new_positions
            ]
            post_ridx[layer_indices, slot_idx] = new_ridx[
                layer_indices, matched_new_positions
            ]
            used_new_indices[layer_indices, matched_new_positions] = True
            preserved_positions[layer_indices, slot_idx] = True

        # Second pass: fill the free slots with the unconsumed new entries.
        # Every match above marks exactly one source and one destination in
        # the same layer, so each row has equally many free slots and leftover
        # entries. Boolean indexing enumerates both row-major, hence they pair
        # up layer by layer in ascending slot order.
        remaining_mask = ~used_new_indices
        fill_mask = ~preserved_positions
        post_local[fill_mask] = new_local[remaining_mask]
        post_ridx[fill_mask] = new_ridx[remaining_mask]

    # Convert back to torch and move to original device
    post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
    post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
    return post_phy2log, post_phy_replicas_idx
@@ -228,7 +329,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy): num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster num_ranks: number of ranks, must be a multiple of `num_nodes` - + old_global_expert_indices: [layers, num_logical_experts], the old global + expert indices. Used to avoid unnecessary weight copying + for experts moving within one rank. Returns: phy2log: [layers, num_replicas], the expert index of each replica @@ -241,14 +344,23 @@ class DefaultEplbPolicy(AbstractEplbPolicy): weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( weight, num_replicas, num_groups, num_nodes, num_ranks ) else: # use global load-balance policy - phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( weight, num_replicas, 1, 1, num_ranks ) + # Optional postprocessing to preserve slots for experts moving + # within the same GPU + # Only apply when the number of GPUs and slots per GPU remain unchanged. + # Helps to avoid unnecessary weight copying when experts move + # within the same GPU. 
+ if old_global_expert_indices is not None: + phy2log, phy_replicas_idx = cls.preserve_intragpu_slots( + phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices + ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 log2phy: torch.Tensor = torch.full( @@ -259,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ) log2phy.view(num_layers, -1).scatter_( -1, - phy2log * maxlogcnt + phyrank, + phy2log * maxlogcnt + phy_replicas_idx, torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( num_layers, -1 ), diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 55856d940..b7b6c11b2 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -6,9 +6,10 @@ The actual execution of the rearrangement. This involves the exchange of expert weights between GPUs. """ -from collections.abc import Iterable, MutableSequence, Sequence -from functools import partial +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +import numpy as np import torch from torch.distributed import ( P2POp, @@ -18,214 +19,318 @@ from torch.distributed import ( get_global_rank, ) +from vllm.logger import init_logger -def idx_local_to_global( - local_idx: int, - local_cnt: int, - ep_rank: int, -) -> int: - """ - Convert a local expert index to a global expert index. - """ - return ep_rank * local_cnt + local_idx +logger = init_logger(__name__) -def idx_global_to_local( - global_idx: int, - local_cnt: int, - ep_rank: int, -) -> int: - """ - Convert a global expert index to a local expert index. 
@dataclass
class RecvMetadata:
    """Metadata describing remote receives during EPLB rebalancing."""

    recv_primary_mask: np.ndarray
    """Boolean mask of (num_local_experts,) marking the local rows chosen as
    the primary destination of a remotely received expert."""
    recv_count: int
    """Number of distinct experts received remotely for the layer."""
    recv_expert_ids: np.ndarray
    """Expert ids (num_local_experts,) of the remotely received experts; only
    the first `recv_count` entries are meaningful."""
    recv_dst_rows: np.ndarray
    """Local row indices (num_local_experts,) that the received experts are
    written into; only the first `recv_count` entries are meaningful."""


# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]


def _ranks_hosting_experts(
    indices: np.ndarray,
    query_experts: np.ndarray,
    num_local_experts: int,
) -> dict[int, list[int]]:
    """Map each queried expert found in `indices` to the ranks that host it.

    A position `p` of `indices` belongs to rank `p // num_local_experts`.
    Each expert's rank list is deduplicated and ascending. Because rank grows
    monotonically with position, ascending order is also first-appearance
    order. Experts absent from `indices` get no entry.
    """
    indices = np.asarray(indices)
    positions = np.flatnonzero(np.isin(indices, query_experts))
    if positions.size == 0:
        return {}

    experts = indices[positions]
    ranks = positions // num_local_experts

    # Sort by expert (primary key), then rank, and drop repeated pairs.
    order = np.lexsort((ranks, experts))
    experts = experts[order]
    ranks = ranks[order]
    keep = np.ones(experts.shape[0], dtype=bool)
    keep[1:] = (experts[1:] != experts[:-1]) | (ranks[1:] != ranks[:-1])
    experts = experts[keep]
    ranks = ranks[keep]

    # Split the rank array wherever the expert id changes.
    boundaries = np.flatnonzero(np.diff(experts)) + 1
    group_starts = np.concatenate(([0], boundaries))
    return {
        int(expert): group.tolist()
        for expert, group in zip(
            experts[group_starts].tolist(), np.split(ranks, boundaries)
        )
    }


def get_ep_ranks_with_experts_batch(
    expert_ids: np.ndarray,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Get the ranks of the experts that need to be exchanged.

    Args:
        expert_ids: 1D array of expert indices to query.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two dictionaries mapping expert_id to:
        - ranks_to_send: The ranks that hold this expert in `old_indices`,
          in order of first appearance.
        - ranks_to_recv: The ranks that hold this expert in `new_indices`
          but not in `old_indices`, i.e. that cannot obtain it locally.
        Every queried expert has an entry in both dictionaries (possibly an
        empty list).
    """
    expert_ids = np.asarray(expert_ids)

    # Fast path: if no experts, return empty dicts
    if expert_ids.size == 0:
        return {}, {}

    unique_experts = np.unique(expert_ids)

    ranks_to_send_map = _ranks_hosting_experts(
        old_indices, unique_experts, num_local_experts
    )
    new_hosts = _ranks_hosting_experts(new_indices, unique_experts, num_local_experts)

    # A rank that already holds the expert copies it locally, so it is not a
    # receiver.
    ranks_to_recv_map: dict[int, list[int]] = {}
    for expert, hosts in new_hosts.items():
        local_holders = set(ranks_to_send_map.get(expert, ()))
        ranks_to_recv_map[expert] = [r for r in hosts if r not in local_holders]

    # Experts that appear only in old (send only), only in new (recv only) or
    # in neither still get an (empty) entry on the missing side.
    for expert in unique_experts.tolist():
        ranks_to_send_map.setdefault(expert, [])
        ranks_to_recv_map.setdefault(expert, [])

    return ranks_to_send_map, ranks_to_recv_map
+ ep_group: Distributed process group for expert parallel comms. + + Returns: + is_unchanged (np.ndarray): (num_local_experts,), True where an expert row + is unchanged after rebalance. + is_received_locally (np.ndarray): (num_local_experts,), True where a row + can be updated from local data. + RecvMetadata: Metadata needed for completing remote weight transfers. """ + assert old_indices.shape == new_indices.shape ep_rank = ep_group.rank() - local2global = partial( - idx_local_to_global, - local_cnt=num_local_experts, - ep_rank=ep_rank, + + recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_) + send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64) + send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32) + recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64) + recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32) + + base = ep_rank * num_local_experts + local_rows = np.arange(num_local_experts, dtype=np.int32) + local_global = base + local_rows + + old_local_expert_ids = old_indices[local_global] + new_local_expert_ids = new_indices[local_global] + + # Unchanged mask + is_unchanged = old_local_expert_ids == new_local_expert_ids + + # Local receive eligibility + new_valid = new_local_expert_ids != -1 + can_recv_local = np.isin( + new_local_expert_ids, old_local_expert_ids, assume_unique=False + ) + is_received_locally = np.logical_or( + is_unchanged, np.logical_and(new_valid, can_recv_local) ) - # 0. Do nothing for experts that did not change. 
- is_unchanged = [ - old_indices[local2global(i)] == new_indices[local2global(i)] - for i in range(num_local_experts) - ] + # Send map: first src row per unique expert present locally in old mapping + send_count = 0 + valid_old = old_local_expert_ids != -1 + if np.any(valid_old): + uniq_experts, first_idx = np.unique( + old_local_expert_ids[valid_old], return_index=True + ) + filtered_rows = local_rows[valid_old] + src_rows = filtered_rows[first_idx] + send_count = int(uniq_experts.shape[0]) + send_expert_ids[:send_count] = uniq_experts + send_src_rows[:send_count] = src_rows - # 1. Perform weight copy inside the local rank. - is_received_locally = is_unchanged[:] - for src in range(num_local_experts): - src_global = local2global(src) - for dst in range(num_local_experts): - dst_global = local2global(dst) - if is_received_locally[dst]: - continue - if old_indices[src_global] == -1 or new_indices[dst_global] == -1: - continue - if old_indices[src_global] == new_indices[dst_global]: - is_received_locally[dst] = True - for weight, buffer in zip(expert_weights, expert_weights_buffer): - with torch.cuda.stream(cuda_stream): - buffer[dst].copy_(weight[src], non_blocking=True) + # Recv map: primary dst per unique expert needed remotely + recv_count = 0 + need_recv_mask = np.logical_and(~is_received_locally, new_valid) + if np.any(need_recv_mask): + desired_experts = new_local_expert_ids[need_recv_mask] + desired_dsts = local_rows[need_recv_mask] + uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True) + dst_rows = desired_dsts[uniq_indices] + recv_count = int(uniq_recv_experts.shape[0]) + recv_expert_ids[:recv_count] = uniq_recv_experts + recv_dst_rows[:recv_count] = dst_rows + recv_primary_mask[dst_rows] = True + + eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally) + + # 1. 
Local moves into tmp buffers + if bool(eligible_local_buffer_mask.any()) and send_count > 0: + dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist() + expert_to_src_map = dict( + zip(send_expert_ids[:send_count], send_src_rows[:send_count]) + ) + for dst in dest_indices: + expert = new_local_expert_ids[dst] + src_local = expert_to_src_map.get(expert, -1) + if src_local != -1: + for w, b in zip(expert_weights, expert_weights_buffers): + b[dst].copy_(w[src_local], non_blocking=True) p2p_ops: list[P2POp] = [] - # 2. Initiate sending of weights. - experts_send_loc: dict[int, int] = {} - for src in range(num_local_experts): - expert = old_indices[local2global(src)] - if expert == -1: - continue - if expert in experts_send_loc: - continue - experts_send_loc[expert] = src + # Pre-compute global ranks mapping + ep_size = ep_group.size() + rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} - # We need to sort here to match send/recv - for expert, src in sorted(experts_send_loc.items()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, + # 2. 
Post sends + if send_count > 0: + experts = send_expert_ids[:send_count] + srcs = send_src_rows[:send_count] + order = np.argsort(experts, kind="stable") + experts = experts[order] + srcs = srcs[order] + + send_map, recv_map = get_ep_ranks_with_experts_batch( + experts, num_local_experts, old_indices, new_indices, ) - # Calculate the ranks to send by this rank - num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) - sender_pos = ranks_to_send.index(ep_rank) - recv_begin = sender_pos * num_dst_per_sender - recv_end = recv_begin + num_dst_per_sender - recv_ranks = ranks_to_recv[recv_begin:recv_end] + for expert, src in zip(experts.tolist(), srcs.tolist()): + ranks_to_send = send_map[expert] + ranks_to_recv = recv_map[expert] + if not ranks_to_send or not ranks_to_recv: + continue + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + for dst in recv_ranks: + dst_global = rank_to_global[dst] + p2p_ops += [ + P2POp( + torch.distributed.isend, + w[src], + dst_global, + ) + for w in expert_weights + ] - # Tackle remainders - remainder_start = len(ranks_to_send) * num_dst_per_sender - recver_pos = remainder_start + sender_pos - if recver_pos < len(ranks_to_recv): - recv_ranks.append(ranks_to_recv[recver_pos]) + # 3. 
Post recvs + if recv_count > 0: + experts = recv_expert_ids[:recv_count] + dsts = recv_dst_rows[:recv_count] + order = np.argsort(experts, kind="stable") + experts = experts[order] + dsts = dsts[order] - for dst in recv_ranks: - dst_global = get_global_rank(ep_group, dst) + send_map, recv_map = get_ep_ranks_with_experts_batch( + experts, + num_local_experts, + old_indices, + new_indices, + ) + + for expert, dst in zip(experts.tolist(), dsts.tolist()): + ranks_to_send = send_map[expert] + ranks_to_recv = recv_map[expert] + if not ranks_to_send or not ranks_to_recv: + continue + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src = ranks_to_send[recver_pos - remainder_start] + src_global = rank_to_global[src] p2p_ops += [ P2POp( - torch.distributed.isend, - weight[src], - dst_global, + torch.distributed.irecv, + b[dst], + src_global, ) - for weight in expert_weights + for b in expert_weights_buffers ] - # 3. Initiate receiving of weights. 
- experts_recv_loc: dict[int, int] = {} - for dst in range(num_local_experts): - if is_received_locally[dst]: - continue - expert = new_indices[local2global(dst)] - if expert == -1: - continue - if expert in experts_recv_loc: - continue - experts_recv_loc[expert] = dst - - # We need to sort here to match send/recv - for expert, dst in sorted(experts_recv_loc.items()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - new_indices, - ) - - # Calculate the rank to recv by this rank - num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) - recver_pos = ranks_to_recv.index(ep_rank) - remainder_start = len(ranks_to_send) * num_dst_per_sender - if recver_pos < remainder_start: - src = ranks_to_send[recver_pos // num_dst_per_sender] - else: - src = ranks_to_send[recver_pos - remainder_start] - - src_global = get_global_rank(ep_group, src) - p2p_ops += [ - P2POp( - torch.distributed.irecv, - weight[dst], - src_global, - ) - for weight in expert_weights_buffer - ] - # 4. Execute the P2P operations. The real communication happens here. 
if p2p_ops and cuda_stream is not None: with torch.cuda.stream(cuda_stream): @@ -237,38 +342,95 @@ def move_to_buffer( for req in reqs: req.wait() # wait for the communication to finish - return is_unchanged, is_received_locally, experts_recv_loc + return ( + is_unchanged, + is_received_locally, + RecvMetadata( + recv_primary_mask=recv_primary_mask, + recv_count=recv_count, + recv_expert_ids=recv_expert_ids, + recv_dst_rows=recv_dst_rows, + ), + ) def move_from_buffer( expert_weights: Iterable[torch.Tensor], - expert_weights_buffer: list[torch.Tensor], - is_unchanged: list[bool], - is_received_locally: list[bool], - experts_recv_loc: dict[int, int], - new_indices: Sequence[int], - ep_group: ProcessGroup, + expert_weights_buffers: list[torch.Tensor], + is_unchanged: np.ndarray, + is_received_locally: np.ndarray, + recv_metadata: RecvMetadata, + new_indices: np.ndarray, + ep_rank: int, ) -> None: - ep_rank = ep_group.rank() - num_local_experts = len(is_unchanged) + """ + Copies expert weights from communication buffers back to the target weight tensors + after EPLB rebalancing. - local2global = partial( - idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank + Args: + expert_weights: List of the actual MoE layer weights used in the execution. + expert_weights_buffers: Intermediate buffers containing the expert weights + after the transfer is completed. + is_unchanged: (num_local_experts,), True where an expert row is unchanged. + is_received_locally: (num_local_experts,), True where a row is updated locally. + recv_metadata: RecvMetadata containing remote receive metadata. + new_indices: (num_experts_total,) mapping from global physical slot + (ep_rank * num_local_experts + local row) to desired expert id, after rebalance. + ep_rank: Rank of the process in the expert parallel group. 
+ """ + recv_primary_mask = recv_metadata.recv_primary_mask + recv_count = recv_metadata.recv_count + recv_expert_ids = recv_metadata.recv_expert_ids + recv_dst_rows = recv_metadata.recv_dst_rows + num_local_experts = is_unchanged.shape[0] + + # Mask for rows to copy back from buffers: + # copy if locally received OR remote primary recv + copy_mask = np.logical_or(is_received_locally, recv_primary_mask) + dest_mask_np = np.logical_and(~is_unchanged, copy_mask) + if bool(dest_mask_np.any()): + dest_indices = np.nonzero(dest_mask_np)[0].tolist() + for dst in dest_indices: + for w, b in zip(expert_weights, expert_weights_buffers): + w[dst].copy_(b[dst], non_blocking=True) + + if recv_count == 0: + return + + # Duplicate remote received rows to non-primary duplicate dsts + base = ep_rank * num_local_experts + local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] + duplicate_mask = np.logical_and( + np.logical_and(~is_unchanged, ~is_received_locally), + np.logical_and(~recv_primary_mask, local_experts != -1), ) + # All received experts are unique in the destination, so no need to copy duplicates + if not bool(duplicate_mask.any()): + return - for dst in range(num_local_experts): - if is_unchanged[dst]: - continue - if is_received_locally[dst]: - for weight, buffer in zip(expert_weights, expert_weights_buffer): - weight[dst].copy_(buffer[dst], non_blocking=True) - else: - expert = new_indices[local2global(dst)] - if expert == -1: - continue - src = experts_recv_loc[expert] - for weight, buffer in zip(expert_weights, expert_weights_buffer): - weight[dst].copy_(buffer[src], non_blocking=True) + dup_dst_rows = np.nonzero(duplicate_mask)[0] + dup_experts = local_experts[dup_dst_rows] + + prim_experts = recv_expert_ids[:recv_count] + prim_dsts = recv_dst_rows[:recv_count] + order = np.argsort(prim_experts, kind="stable") + prim_experts_sorted = prim_experts[order] + prim_dsts_sorted = prim_dsts[order] + pos = np.searchsorted(prim_experts_sorted, 
dup_experts) + valid = np.logical_and( + pos < prim_experts_sorted.shape[0], + prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)] + == dup_experts, + ) + if not bool(valid.any()): + return + + matched_dst_rows = dup_dst_rows[valid] + matched_src_rows = prim_dsts_sorted[pos[valid]] + + for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()): + for w in expert_weights: + w[dst].copy_(w[src], non_blocking=True) async def transfer_layer( @@ -281,7 +443,7 @@ async def transfer_layer( layer: int = 0, cuda_stream: torch.cuda.Stream | None = None, rank_mapping: dict[int, int] | None = None, -) -> tuple[list[bool], list[bool], dict[int, int]]: +) -> MoveToBufferResult: """ Rearranges the expert weights in place according to the new expert indices. @@ -299,6 +461,13 @@ async def transfer_layer( is_profile (bool): If `True`, do not perform any actual weight copy. This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. + + Returns: + is_unchanged (np.ndarray): (num_local_experts,), True where expert + is left unchanged. + is_received_locally (np.ndarray): (num_local_experts,), True where expert + can be received locally. + RecvMetadata: Metadata needed for completing remote weight transfers. 
""" ep_size = ep_group.size() if rank_mapping is not None: @@ -323,16 +492,19 @@ async def transfer_layer( assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert num_physical_experts == ep_size * num_local_physical_experts - is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( + old_global_expert_indices_np = old_global_expert_indices.cpu().numpy() + new_global_expert_indices_np = new_global_expert_indices.cpu().numpy() + + is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices=old_global_expert_indices[layer].tolist(), - new_indices=new_global_expert_indices[layer].tolist(), + old_indices=old_global_expert_indices_np[layer], + new_indices=new_global_expert_indices_np[layer], expert_weights=expert_weights[layer], - expert_weights_buffer=expert_weights_buffer, + expert_weights_buffers=expert_weights_buffer, cuda_stream=cuda_stream, ep_group=ep_group, ) - return is_unchanged, is_received_locally, experts_recv_loc + return is_unchanged, is_received_locally, recv_metadata def rearrange_expert_weights_inplace( @@ -388,19 +560,17 @@ def rearrange_expert_weights_inplace( ep_size = ep_group.size() assert num_physical_experts == ep_size * num_local_physical_experts - # A buffer to hold the expert weights in one layer during the exchange. + first_layer_weights = list(expert_weights[0]) + # Buffers to hold the expert weights during the exchange. # NOTE: Currently we assume the same weights across different layers # have the same shape. 
- expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] - + weights_buffer: list[torch.Tensor] = [ + torch.empty_like(w) for w in first_layer_weights + ] if is_profile: - # Maximum send size is to send all local experts to all ranks, - # So we use a dummy `all_gather` to reserve enough communication buffer - for weight, buffer in zip(expert_weights[0], expert_weights_buffer): - # A `/dev/null`-like buffer to avoid real memory allocation + # Reserve communication buffers via a minimal dummy all_gather on first layer + for weight, buffer in zip(expert_weights[0], weights_buffer): dummy_recv_buffer = [buffer for _ in range(ep_size)] - # NOTE(bowen): Needed this barrier to avoid OOM during actual - # execution. I'm not very sure why this is needed torch.distributed.barrier() all_gather( dummy_recv_buffer, @@ -409,32 +579,32 @@ def rearrange_expert_weights_inplace( ) return - old_global_expert_indices_cpu = old_global_expert_indices.cpu() - new_global_expert_indices_cpu = new_global_expert_indices.cpu() - # NOTE(bowen): We need this synchronize to run, but I don't know why. # If you figure out the reason, please let me know -- thank you! 
torch.cuda.synchronize() - for layer in range(num_moe_layers): - is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( + old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy() + new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() + + for layer_idx in range(num_moe_layers): + is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices=old_global_expert_indices_cpu[layer].tolist(), - new_indices=new_global_expert_indices_cpu[layer].tolist(), - expert_weights=expert_weights[layer], - expert_weights_buffer=expert_weights_buffer, + old_indices=old_global_expert_indices_cpu[layer_idx], + new_indices=new_global_expert_indices_cpu[layer_idx], + expert_weights=expert_weights[layer_idx], + expert_weights_buffers=weights_buffer, cuda_stream=None, ep_group=ep_group, ) move_from_buffer( - expert_weights=expert_weights[layer], - expert_weights_buffer=expert_weights_buffer, + expert_weights=expert_weights[layer_idx], + expert_weights_buffers=weights_buffer, is_unchanged=is_unchanged, is_received_locally=is_received_locally, - experts_recv_loc=experts_recv_loc, - new_indices=new_global_expert_indices[layer].tolist(), - ep_group=ep_group, + recv_metadata=recv_metadata, + new_indices=new_global_expert_indices_cpu[layer_idx], + ep_rank=ep_group.rank(), ) @@ -526,4 +696,4 @@ def _map_new_expert_indices_with_rank_mapping( return mapped_expert_indices -__all__ = ["transfer_layer", "move_from_buffer"] +__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]