diff --git a/tests/kernels/moe/test_routing.py b/tests/kernels/moe/test_routing.py index 47c0fb8a2..7b065cf15 100644 --- a/tests/kernels/moe/test_routing.py +++ b/tests/kernels/moe/test_routing.py @@ -8,6 +8,9 @@ import torch from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.fused_moe.router.base_router import ( + eplb_map_to_physical_and_record, +) from vllm.model_executor.layers.fused_moe.router.router_factory import ( create_fused_moe_router, ) @@ -55,11 +58,13 @@ def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerSta logical_replica_count = torch.ones( global_num_experts, dtype=torch.int64, device="cuda" ) + should_record_tensor = torch.ones((), dtype=torch.bool, device="cuda") return EplbLayerState( expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, + should_record_tensor=should_record_tensor, ) @@ -581,3 +586,152 @@ def test_custom( # hidden_states, router_logits = make_test_data(m, k, global_num_experts) # topk_weights, topk_ids = router.select_experts(hidden_states, router_logits) + + +# --------------------------------------------------------------------------- +# Tests for eplb_map_to_physical_and_record +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("record_enabled", [True, False]) +@pytest.mark.parametrize( + "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load", + [ + pytest.param( + # logical i → physical i + [[0], [1], [2], [3]], + [1, 1, 1, 1], + 4, + [[0, 1], [2, 3], [0, 2]], + [[0, 1], [2, 3], [0, 2]], + [2, 1, 2, 1], + id="identity", + ), + pytest.param( + # logical 0→3, 1→0, 2→1, 3→2 + [[3], [0], [1], [2]], + [1, 1, 1, 1], + 4, + [[0, 1], [2, 3], [0, 2]], + [[3, 0], [1, 2], [3, 1]], + [1, 2, 1, 2], + id="shuffled", + ), + pytest.param( + # logical 0→5, 1→2, 2→7, 3→0 in a larger physical space + [[5], [2], [7], [0]], + [1, 1, 1, 1], + 8, + [[0, 1], [2, 3]], + [[5, 2], [7, 0]], + [1, 0, 1, 0, 0, 1, 0, 1], + id="sparse", + ), + ], +) +def test_eplb_map_no_redundancy( + record_enabled, + l2p_map, + replica_count, + num_physical, + topk_ids, + expected_out, + expected_load, +): + l2p = torch.tensor(l2p_map, dtype=torch.int64, device="cuda") + rc = torch.tensor(replica_count, dtype=torch.int64, device="cuda") + load = torch.zeros(num_physical, dtype=torch.int32, device="cuda") + rec = torch.tensor(record_enabled, dtype=torch.bool, device="cuda") + ids = torch.tensor(topk_ids, dtype=torch.int32, device="cuda") + + out = eplb_map_to_physical_and_record( + topk_ids=ids, + expert_load_view=load, + logical_to_physical_map=l2p, + logical_replica_count=rc, + record_enabled=rec, + ) + + exp_out = torch.tensor(expected_out, dtype=out.dtype, device="cuda") + torch.testing.assert_close(out, exp_out) + + if record_enabled: + exp_load = torch.tensor(expected_load, dtype=torch.int32, device="cuda") + torch.testing.assert_close(load, exp_load) + else: + assert load.sum().item() == 0 + + +@pytest.mark.parametrize("record_enabled", [True, False]) +@pytest.mark.parametrize( + "l2p_map, replica_count, num_physical, topk_ids, expected_out, expected_load", + [ + pytest.param( + # experts 0,1 have 2 replicas; 2,3 have 1 + [[0, 4], [1, 5], [2, -1], [3, -1]], + [2, 2, 1, 1], + 6, + [[0, 1], [2, 3], [0, 2]], + # offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%1=0→p2, + # 3→3%1=0→p3, 4→4%2=0→p0, 5→5%1=0→p2 + [[0, 5], [2, 3], [0, 2]], + [2, 0, 
2, 1, 0, 1], + id="partial", + ), + pytest.param( + # all 4 experts have 2 replicas + [[0, 4], [1, 5], [2, 6], [3, 7]], + [2, 2, 2, 2], + 8, + [[0, 1], [2, 3], [0, 2]], + # offs: 0→0%2=0→p0, 1→1%2=1→p5, 2→2%2=0→p2, + # 3→3%2=1→p7, 4→4%2=0→p0, 5→5%2=1→p6 + [[0, 5], [2, 7], [0, 6]], + [2, 0, 1, 0, 0, 1, 1, 1], + id="full", + ), + pytest.param( + # expert 0: 4 replicas, experts 1,2: 2 replicas + [[0, 3, 5, 7], [1, 4, -1, -1], [2, 6, -1, -1]], + [4, 2, 2], + 8, + [[0, 1], [2, 0], [1, 2]], + # offs: 0→0%4=0→p0, 1→1%2=1→p4, 2→2%2=0→p2, + # 3→3%4=3→p7, 4→4%2=0→p1, 5→5%2=1→p6 + [[0, 4], [2, 7], [1, 6]], + [1, 1, 1, 0, 1, 0, 1, 1], + id="uneven", + ), + ], +) +def test_eplb_map_with_redundancy( + record_enabled, + l2p_map, + replica_count, + num_physical, + topk_ids, + expected_out, + expected_load, +): + l2p = torch.tensor(l2p_map, dtype=torch.int64, device="cuda") + rc = torch.tensor(replica_count, dtype=torch.int64, device="cuda") + load = torch.zeros(num_physical, dtype=torch.int32, device="cuda") + rec = torch.tensor(record_enabled, dtype=torch.bool, device="cuda") + ids = torch.tensor(topk_ids, dtype=torch.int32, device="cuda") + + out = eplb_map_to_physical_and_record( + topk_ids=ids, + expert_load_view=load, + logical_to_physical_map=l2p, + logical_replica_count=rc, + record_enabled=rec, + ) + + exp_out = torch.tensor(expected_out, dtype=out.dtype, device="cuda") + torch.testing.assert_close(out, exp_out) + + if record_enabled: + exp_load = torch.tensor(expected_load, dtype=torch.int32, device="cuda") + torch.testing.assert_close(load, exp_load) + else: + assert load.sum().item() == 0 diff --git a/tests/model_executor/test_routed_experts_capture.py b/tests/model_executor/test_routed_experts_capture.py index 45bf4bcac..f831c8dfc 100644 --- a/tests/model_executor/test_routed_experts_capture.py +++ b/tests/model_executor/test_routed_experts_capture.py @@ -62,6 +62,7 @@ def test_base_router_capture_with_eplb_enabled(): router.eplb_state.expert_load_view = torch.zeros(32, dtype=torch.int64) router.eplb_state.logical_to_physical_map = torch.arange(32).view(32, 1) router.eplb_state.logical_replica_count = torch.ones(32, dtype=torch.int64) + router.eplb_state.should_record_tensor = torch.ones((), dtype=torch.bool) captured = [] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 8afff3af2..cab415a78 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -53,9 +53,9 @@ All2AllBackend = Literal[ class EPLBConfig: """Configuration for Expert Parallel Load Balancing (EP).""" - window_size: int = 1000 + window_size: int = Field(default=1000, gt=0) """Window size for expert load recording.""" - step_interval: int = 3000 + step_interval: int = Field(default=3000, gt=0) """ Interval for rearranging experts in expert parallelism. @@ -71,7 +71,7 @@ class EPLBConfig: Log the balancedness each step of expert parallelism. This is turned off by default since it will cause communication overhead. """ - log_balancedness_interval: int = 1 + log_balancedness_interval: int = Field(default=1, gt=0) """ Interval for logging the balancedness. 
""" diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py index 8b05c58ea..5f54beedb 100644 --- a/vllm/distributed/elastic_ep/elastic_execute.py +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -399,6 +399,7 @@ class ElasticEPScalingExecutor: eplb_model_state.logical_to_physical_map, eplb_model_state.logical_replica_count, ) + eplb_state._init_should_record_tensor(model) model.update_physical_experts_metadata( num_physical_experts=num_physical_experts, num_local_physical_experts=num_local_experts, diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 180c12abc..be7c3de4d 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -272,6 +272,13 @@ class EplbState: Interval for expert rearrangement steps. This is a constant and is taken from the config. """ + self.should_record_tensor: torch.Tensor | None = None + """ + Shared scalar bool tensor for all layers. Every + :class:`EplbLayerState` holds a reference to the **same** object so + a single ``.fill_()`` updates all layers at once. Allocated on the + first call to :meth:`_init_should_record_tensor`. + """ self.is_async: bool = False """ The flag indicates whether the EPLB is running in async mode. @@ -462,7 +469,7 @@ class EplbState: logical_to_physical_map, logical_replica_count, ) - + self._init_should_record_tensor(model) expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] model_state = EplbModelState( @@ -582,15 +589,18 @@ class EplbState: # Update the expert load sliding window if not is_dummy: + should_record = self._should_record_current_step(log_stats=log_stats) for eplb_model_state in self.model_states.values(): - eplb_model_state.expert_load_window[self.expert_load_window_step] = ( - eplb_model_state.expert_load_pass.clone() - ) - eplb_model_state.expert_load_pass.zero_() + if should_record: + eplb_model_state.expert_load_window[ + self.expert_load_window_step + ].copy_(eplb_model_state.expert_load_pass) + eplb_model_state.expert_load_pass.zero_() - self.expert_load_window_step += 1 - if self.expert_load_window_step >= self.expert_load_window_size: - self.expert_load_window_step = 0 + if should_record: + self.expert_load_window_step += 1 + if self.expert_load_window_step >= self.expert_load_window_size: + self.expert_load_window_step = 0 # Step the expert rearrangement step # Note that even if this is a dummy step, we still increment the @@ -617,11 +627,66 @@ class EplbState: eplb_model_state.rebalanced for eplb_model_state in self.model_states.values() ): - # Still performing asynchronous rearrangement + # Still performing asynchronous rearrangement; update + # should_record (step > step_interval, so always True) and + # bail out before the step counter is reset. + self._update_layer_should_record(log_stats=log_stats) return self.expert_rearrangement_step = 0 self.rearrange() + self._update_layer_should_record(log_stats=log_stats) + + def _should_record_current_step(self, log_stats: bool = False) -> bool: + """Return whether expert-load recording should be enabled this step. + + Recording is enabled when we are close to either: + 1) The next rearrangement step, so the sliding window is ready. + 2) The next balancedness logging step, when log_stats is enabled. 
+ """ + steps_remaining = ( + self.expert_rearrangement_step_interval - self.expert_rearrangement_step + ) + should_record_for_rearrange = steps_remaining <= self.expert_load_window_size + + if not log_stats: + return should_record_for_rearrange + + log_interval = self.parallel_config.eplb_config.log_balancedness_interval + steps_until_next_log = ( + log_interval - (self.expert_rearrangement_step % log_interval) + ) % log_interval + should_record_for_log = steps_until_next_log <= self.expert_load_window_size + return should_record_for_rearrange or should_record_for_log + + def _update_layer_should_record(self, log_stats: bool = False) -> None: + """Update the shared ``should_record_tensor`` for all layers.""" + if self.should_record_tensor is not None: + self.should_record_tensor.fill_( + self._should_record_current_step(log_stats=log_stats) + ) + + def _init_should_record_tensor(self, model: "MixtureOfExperts") -> None: # type: ignore[name-defined] + """Allocate (once) and propagate the shared ``should_record_tensor``. + + Must be called after :meth:`model.set_eplb_state` so that each + layer's ``eplb_state`` is already populated with the tensor views. + """ + layer_states = [ + layer.eplb_state + for layer in model.moe_layers + if hasattr(layer, "eplb_state") + and isinstance(layer.eplb_state, EplbLayerState) + ] + + if self.should_record_tensor is None and layer_states: + self.should_record_tensor = torch.ones( + (), dtype=torch.bool, device=self.device + ) + + for ls in layer_states: + ls.should_record_tensor = self.should_record_tensor + def rearrange( self, is_profile: bool = False, @@ -993,6 +1058,17 @@ class EplbLayerState: expert_load_view: torch.Tensor | None = None logical_to_physical_map: torch.Tensor | None = None logical_replica_count: torch.Tensor | None = None + should_record_tensor: torch.Tensor | None = None + """ + Shared scalar bool tensor controlling whether to accumulate expert load + metrics during this forward pass. All layers reference the **same** + tensor object, which is owned and updated by :class:`EplbState`. + + Set to ``False`` for the first ``step_interval - window_size`` steps of + each rearrangement period: those steps would be overwritten in the + sliding window before the next rearrangement, so recording them wastes + GPU work. + """ def _node_count_with_rank_mapping( diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 6332827d1..bcc66887f 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -10,61 +10,49 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton if current_platform.is_cuda_alike(): - @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) - def eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> torch.Tensor: - """ - Map the logical expert ids to physical expert ids - and record the expert load metrics. 
+ @triton.jit + def _eplb_map_and_record_i32_kernel( + topk_ids_ptr, + logical_replica_count_ptr, + logical_to_physical_ptr, + out_ids_ptr, + out_ptr, + record_enabled_ptr, + num_logical_experts, + map_slots, + out_size, + numel, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < numel - This will select a pseudo-random replica for each logical expert. - Only used for EPLB. - - Args: - topk_ids: The logical expert ids. - expert_load_view: The expert load view. - logical_to_physical_map: The logical to physical map. - logical_replica_count: The logical replica count. - - Returns: - The physical expert ids. - """ + expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int64) + valid_expert = (expert_id >= 0) & (expert_id < num_logical_experts) + safe_expert_id = tl.where(valid_expert, expert_id, 0) # 1. Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert - - # In case `indices_type` is not `torch.long` or `torch.int`, - # e.g. `torch.uint32` as required by dispatch/combine kernels - topk_ids_long = topk_ids.long() - # Use (token position) modulo (replica count) - # to deterministically choose a replica - replica_count = logical_replica_count[topk_ids_long] - # Flatten-position based index, reshaped back to `topk_ids` shape - pos_indices = torch.arange( - topk_ids.numel(), device=topk_ids.device, dtype=torch.long - ).reshape_as(topk_ids) - # Compute pseudo-random indices by modulo - replica_indices = (pos_indices % replica_count).unsqueeze(-1) - physical_ids = ( - logical_to_physical_map[topk_ids_long] - .gather(-1, replica_indices) - .squeeze(-1) + replica_count = tl.load( + logical_replica_count_ptr + safe_expert_id, + mask=mask & valid_expert, + other=1, ) - - topk_ids = physical_ids + # Avoid invalid modulo/div by forcing at least 1. + replica_count = tl.maximum(replica_count, 1) + # Match torch.compile path: use flattened token position. + replica_idx = offs % replica_count # 2. Record expert load metrics. # TODO(bowen): When using `FusedMoEModularKernel`, this # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalizeModular` will return the expert + # `FusedMoEPrepareAndFinalize` will return the expert # token count, in some cases directly from the kernel. # However, now there are many code paths not using # the modular kernel, e.g. calling `fused_experts`, @@ -73,17 +61,63 @@ if current_platform.is_cuda_alike(): # If later refactor moved all the MoE kernel calls # to the modular kernel, we can move this logic there # to achieve better efficiency. - - # `expert_load_view`: (num_physical_experts,) - - # `torch.bincount` is not compilable, so use `scatter_add_` instead. 
- topk_ids_flatten = topk_ids.flatten() - expert_load_view.scatter_add_( - dim=0, - index=topk_ids_flatten.long(), - src=torch.ones_like(topk_ids_flatten).to(expert_load_view), + map_index = safe_expert_id * map_slots + replica_idx + physical_id = tl.load( + logical_to_physical_ptr + map_index, + mask=mask & valid_expert, + other=-1, + ) + tl.store(out_ids_ptr + offs, physical_id, mask=mask) + + record_enabled = tl.load(record_enabled_ptr) != 0 + valid = mask & record_enabled & (physical_id >= 0) & (physical_id < out_size) + safe_physical_id = tl.where(physical_id >= 0, physical_id, 0) + tl.atomic_add(out_ptr + safe_physical_id, 1, mask=valid) + + def _eplb_map_and_record_triton( + topk_ids: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + expert_load_view: torch.Tensor, + record_enabled: torch.Tensor, + ) -> torch.Tensor: + topk_ids_in = topk_ids.contiguous().to(dtype=torch.int32) + numel = topk_ids_in.numel() + if numel == 0: + return topk_ids + out_flat = torch.empty((numel,), device=topk_ids.device, dtype=topk_ids.dtype) + grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),) + assert expert_load_view.is_contiguous() + _eplb_map_and_record_i32_kernel[grid]( + topk_ids_in, + logical_replica_count.contiguous(), + logical_to_physical_map.contiguous(), + out_flat, + expert_load_view, + record_enabled, + logical_replica_count.shape[0], + logical_to_physical_map.shape[1], + expert_load_view.shape[0], + numel, + BLOCK_SIZE=256, + ) + return out_flat.reshape(topk_ids.shape) + + def eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + record_enabled: torch.Tensor, + ) -> torch.Tensor: + # Fused triton implementation: mapping + optional recording in one kernel. + return _eplb_map_and_record_triton( + topk_ids=topk_ids, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + expert_load_view=expert_load_view, + record_enabled=record_enabled, ) - return topk_ids else: def eplb_map_to_physical_and_record( @@ -91,8 +125,8 @@ else: expert_load_view: torch.Tensor, logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, + record_enabled: torch.Tensor, ) -> torch.Tensor: - # CPU fallback: no EPLB so just return as is return topk_ids @@ -146,6 +180,10 @@ class BaseRouter(FusedMoERouter): raise ValueError( "enable_eplb=True requires logical_replica_count != None" ) + if self.eplb_state.should_record_tensor is None: + raise ValueError( + "enable_eplb=True requires should_record_tensor != None" + ) def _get_indices_type(self) -> torch.dtype | None: """Get the desired indices dtype from the getter function.""" @@ -159,11 +197,13 @@ class BaseRouter(FusedMoERouter): assert self.eplb_state.expert_load_view is not None assert self.eplb_state.logical_to_physical_map is not None assert self.eplb_state.logical_replica_count is not None + assert self.eplb_state.should_record_tensor is not None return eplb_map_to_physical_and_record( topk_ids=topk_ids, - expert_load_view=self.eplb_state.expert_load_view, logical_to_physical_map=self.eplb_state.logical_to_physical_map, logical_replica_count=self.eplb_state.logical_replica_count, + expert_load_view=self.eplb_state.expert_load_view, + record_enabled=self.eplb_state.should_record_tensor, ) return topk_ids
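
For cross-checking the fused kernel above, a minimal eager-mode sketch of the
same mapping/recording semantics (an illustrative reference only, not part of
this change; `eplb_map_reference` is a hypothetical name, and it assumes every
entry of `topk_ids` is a valid logical expert id, i.e. no -1 handling):

    import torch

    def eplb_map_reference(
        topk_ids: torch.Tensor,                 # (num_tokens, topk), logical ids
        expert_load_view: torch.Tensor,         # (num_physical,), accumulator
        logical_to_physical_map: torch.Tensor,  # (num_logical, max_replicas)
        logical_replica_count: torch.Tensor,    # (num_logical,)
        record_enabled: bool,
    ) -> torch.Tensor:
        # Reference-only sketch; assumes all ids are in range.
        ids = topk_ids.flatten().long()
        # Deterministic pseudo-random replica choice: flattened token
        # position modulo the replica count, matching `offs % replica_count`
        # in the Triton kernel (and the removed torch.compile path).
        pos = torch.arange(ids.numel(), device=ids.device)
        replica_idx = pos % logical_replica_count[ids].clamp(min=1)
        physical = logical_to_physical_map[ids, replica_idx]
        if record_enabled:
            # Histogram of tokens per physical expert, like the kernel's
            # guarded `tl.atomic_add`.
            expert_load_view.scatter_add_(
                0,
                physical,
                torch.ones_like(physical, dtype=expert_load_view.dtype),
            )
        return physical.reshape(topk_ids.shape).to(topk_ids.dtype)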