diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py index e8826eb44..44cbdeed4 100644 --- a/tests/test_routing_simulator.py +++ b/tests/test_routing_simulator.py @@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device): envs.environment_variables[env_name] = lambda s=strategy: s # Test the select_experts method - topk_weights, topk_ids, _ = fused_moe.select_experts( + topk_weights, topk_ids = fused_moe.select_experts( hidden_states=hidden_states, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 8fee4038b..3d248e7fb 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, ) from vllm.model_executor.layers.fused_moe.utils import activation_without_mul +from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import ( + ZeroExpertFusedMoE, +) from vllm.triton_utils import HAS_TRITON _config: dict[str, Any] | None = None @@ -54,6 +57,7 @@ __all__ = [ "FusedMoEPrepareAndFinalize", "RoutingMethodType", "SharedFusedMoE", + "ZeroExpertFusedMoE", "activation_without_mul", "override_config", "get_config", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 9c9bc2514..30ff1bf2f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -92,7 +92,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, zero_expert_result = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -110,10 +110,4 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): expert_map=None if self.disable_expert_map else layer.expert_map, ) - if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result + return result diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6a65b0601..2e7267d56 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -32,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, ) @@ -350,8 +349,6 @@ class FusedMoE(CustomOp): num_redundant_experts: int = 0, has_bias: bool = False, is_sequence_parallel=False, - zero_expert_num: int | None = 0, - zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, routing_method_type: int | None = None, @@ -409,8 +406,6 @@ class FusedMoE(CustomOp): self.global_num_experts = num_experts + num_redundant_experts self.logical_num_experts = num_experts - self.zero_expert_num = zero_expert_num - self.zero_expert_type = zero_expert_type # Expert mapping used in self.load_weights self.expert_mapping = expert_mapping @@ -1525,15 +1520,15 @@ class FusedMoE(CustomOp): self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the router logits. Returns: - (topk_weights, topk_ids, zero_expert_result) - (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): - The weights, expert ids, and zero expert computation result. + (topk_weights, topk_ids) + (tuple[torch.Tensor, torch.Tensor]): + The weights and expert ids. **Compatibility**: When EPLB is not enabled, the returned ids are equivalent to global logical ids, so should be compatible with @@ -1655,23 +1650,7 @@ class FusedMoE(CustomOp): assert topk_ids.dtype == indices_type or indices_type is None - # Compute zero expert result if needed - if ( - self.zero_expert_num is not None - and self.zero_expert_num > 0 - and self.zero_expert_type is not None - and self.global_num_experts is not None - ): - zero_expert_result = zero_experts_compute_triton( - expert_indices=topk_ids, - expert_scales=topk_weights, - num_experts=self.global_num_experts, - zero_expert_type=self.zero_expert_type, - hidden_states=hidden_states, - ) - else: - zero_expert_result = None - return topk_weights, topk_ids, zero_expert_result + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: """ @@ -1736,14 +1715,7 @@ class FusedMoE(CustomOp): fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name ) - if self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(fused_output, tuple) - fused_output, zero_expert_result = fused_output - return (reduce_output(fused_output) + zero_expert_result)[ - ..., :og_hidden_states - ] - else: - return reduce_output(fused_output)[..., :og_hidden_states] + return reduce_output(fused_output)[..., :og_hidden_states] else: if current_platform.is_tpu() or current_platform.is_cpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we @@ -1841,13 +1813,6 @@ class FusedMoE(CustomOp): final_hidden_states, ) - if self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - assert self.shared_experts is None - final_hidden_states, zero_expert_result = final_hidden_states - if zero_expert_result is not None: - final_hidden_states += zero_expert_result - if not skip_result_store: if self.shared_experts is None: full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_( @@ -2030,9 +1995,6 @@ class FusedMoE(CustomOp): shared_output, final_hidden_states, ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, tuple) - final_hidden_states, zero_expert_result = final_hidden_states def combine_output(states: torch.Tensor) -> torch.Tensor: if do_naive_dispatch_combine: @@ -2051,9 +2013,6 @@ class FusedMoE(CustomOp): final_hidden_states[0], combine_output(final_hidden_states[1]), ) - elif self.zero_expert_num is not None and self.zero_expert_num > 0: - assert isinstance(final_hidden_states, torch.Tensor) - return (combine_output(final_hidden_states), zero_expert_result) else: return combine_output(final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 6182f10aa..1ee7b65b2 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -295,7 +295,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, zero_expert_result = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -336,13 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=layer.expert_map, ) - if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result + return result def forward_cpu( self, diff --git a/vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py b/vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py new file mode 100644 index 000000000..97d21767f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager + +import torch +from torch import nn + +from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +class ZeroExpertFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of zero experts. + Zero experts perform identity operations (scaled pass-through) instead + of full MLP computations. + + This class uses memoization to avoid redundant routing computation: + routing is computed once and reused for both zero expert computation + and the main FusedMoE forward pass. + """ + + def __init__( + self, + zero_expert_num: int, + zero_expert_type: str, + router: nn.Module, + **kwargs, + ): + # ZeroExpertFusedMoE manages its own custom_routing_function for memoization + assert ( + "custom_routing_function" not in kwargs + or kwargs.get("custom_routing_function") is None + ), ( + "ZeroExpertFusedMoE does not support external custom_routing_function. " + "It manages its own for routing memoization." + ) + + # Automatically slice router's e_score_correction_bias to only include + # real experts (not zero_experts) for the base FusedMoE. + # The full bias will be used temporarily in forward() for routing. + if hasattr(router, "e_score_correction_bias") and "num_experts" in kwargs: + num_real_experts = kwargs["num_experts"] + router_bias = router.e_score_correction_bias + user_bias = kwargs.get("e_score_correction_bias") + + # Use router's bias if: + # 1. User didn't provide bias, or + # 2. User provided full bias (same size as router) + if user_bias is None or user_bias.shape[0] == router_bias.shape[0]: + kwargs["e_score_correction_bias"] = router_bias[:num_real_experts] + + # FusedMoE no longer accepts zero_expert_num/zero_expert_type. + # We handle zero experts ourselves in forward(). + super().__init__(**kwargs) + # Store the actual zero_expert_num and zero_expert_type for our own use + self._actual_zero_expert_num = zero_expert_num + self._actual_zero_expert_type = zero_expert_type + self._router = router # Full router (includes zero experts) + + # Expose zero_expert_num and zero_expert_type as attributes for + # compatibility with quantization methods that check these attributes + self.zero_expert_num = 0 + self.zero_expert_type = None + + # Memoization state for routing results + self._memoized_topk_weights: torch.Tensor | None = None + self._memoized_topk_ids: torch.Tensor | None = None + + # Create custom_routing_function to reuse memoized routing results + def custom_routing_function(hidden_states, gating_output, topk, renormalize): + """Return memoized `topk_weights` and `topk_ids`.""" + if self._memoized_topk_weights is None or self._memoized_topk_ids is None: + raise RuntimeError( + "ZeroExpertFusedMoE: routing results not memoized. " + "Call select_experts first to compute routing." + ) + return self._memoized_topk_weights, self._memoized_topk_ids + + self.custom_routing_function = custom_routing_function + + @contextmanager + def _temporarily_set_attrs(self, **attrs): + """ + Temporarily set attributes using object.__setattr__ and restore them. + + This bypasses nn.Module.__setattr__ to avoid Dynamo tracing issues. + When PyTorch Dynamo traces the forward pass, it cannot handle + nn.Module.__setattr__ calls (which include parameter registration logic), + resulting in "Unsupported" errors. Using object.__setattr__ directly + sets the attribute without triggering nn.Module's custom __setattr__, + allowing Dynamo to trace the code successfully. + """ + originals = {key: getattr(self, key) for key in attrs} + try: + for key, value in attrs.items(): + object.__setattr__(self, key, value) + yield + finally: + for key, value in originals.items(): + object.__setattr__(self, key, value) + + def _compute_zero_expert_result( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | None: + """Compute zero expert results using pre-computed routing.""" + if ( + self._actual_zero_expert_num is None + or self._actual_zero_expert_num <= 0 + or self._actual_zero_expert_type is None + ): + return None + + return zero_experts_compute_triton( + expert_indices=topk_ids.clone(), + expert_scales=topk_weights.clone(), + num_experts=self.logical_num_experts, + zero_expert_type=self._actual_zero_expert_type, + hidden_states=hidden_states, + ) + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, # Full logits including zero experts + ) -> torch.Tensor: + """ + Forward pass with zero expert support and routing memoization. + + Args: + hidden_states: Input hidden states + router_logits: Full router logits (including zero experts) + + Returns: + Combined output from real experts and zero experts + """ + # Prepare temporary attribute overrides for routing computation + temp_attrs = { + "custom_routing_function": None, # Disable for first routing + } + if self._router is not None: + temp_attrs["e_score_correction_bias"] = self._router.e_score_correction_bias + + # Compute routing with temporary attributes + # Pass full router_logits (including zero experts) so that zero experts + # can be properly identified in topk_ids + with self._temporarily_set_attrs(**temp_attrs): + topk_weights, topk_ids = self.select_experts( + hidden_states=hidden_states, + router_logits=router_logits, # Full logits (includes zero experts) + ) + + # Compute zero expert result if needed + zero_expert_result = self._compute_zero_expert_result( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + + # Memoize routing results for reuse in super().forward() + self._memoized_topk_weights = topk_weights + self._memoized_topk_ids = topk_ids + + # Slice router_logits for real experts only + router_logits_sliced = router_logits[..., : self.logical_num_experts] + + # Compute real expert results (will reuse memoized routing via + # custom_routing_function) + # zero_expert_num is already 0, so FusedMoE won't handle zero experts + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits_sliced, + ) + + # Combine results + # Both zero_expert_result and fused_out are computed from the same + # hidden_states, so they should be on the same device. + if zero_expert_result is not None: + fused_out = fused_out + zero_expert_result + + # Clear memoization after use + self._memoized_topk_weights = None + self._memoized_topk_ids = None + + return fused_out diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 314848721..602d02d2f 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -764,7 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 1fd959cb3..efe567704 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -500,7 +500,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index fc359a306..f4038801c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -574,7 +574,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): e_score_correction_bias=layer.e_score_correction_bias, ) - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1166,7 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1403,7 +1403,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1765,7 +1765,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): f"{layer.activation} not supported for Marlin MoE." ) - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1991,7 +1991,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -2607,7 +2607,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." ) assert self.moe_quant_config is not None - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 11097cf36..56b11b22f 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -142,7 +142,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 30ca64238..4b2438133 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1292,13 +1292,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - select_result = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) - topk_weights, topk_ids, zero_expert_result = select_result - if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts, @@ -1353,13 +1351,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result + return result class Fp8OnlineMoEMethod(Fp8MoEMethod): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9dd734f2f..9600bb422 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -639,7 +639,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): "fused GGUF MoE method." ) - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 347c7b200..d2dafca99 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -900,7 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index aa3937d4c..54e8673fc 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -796,7 +796,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) # Expert selection - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1599,7 +1599,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): x_routing, _ = x else: x_routing = x - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 4bedb951a..513f6f7b2 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -370,7 +370,7 @@ class MoeWNA16Method(FusedMoEMethodBase): from vllm.model_executor.layers.fused_moe import fused_experts assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 832925825..dc0fbfa7d 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -896,7 +896,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): raise NotImplementedError("EPLB is not supported for mxfp4") if self.mxfp4_backend == Mxfp4Backend.MARLIN: - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -990,7 +990,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ): from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 0b9b098af..819704803 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -338,7 +338,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -530,7 +530,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -738,7 +738,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index b2ecb0b17..dce9c661e 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -359,7 +359,7 @@ class RTNMoEMethod(FusedMoEMethodBase): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids, _ = layer.select_experts( + topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index c5441283f..774737387 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -179,7 +179,7 @@ class FlashConfig(PretrainedConfig): self.intermediate_size = ( self.ffn_hidden_size if hasattr(self, "ffn_hidden_size") - else self.intermediate_size + else intermediate_size ) if hasattr(self, "moe_intermediate_size"): self.moe_intermediate_size = self.moe_intermediate_size @@ -280,10 +280,6 @@ class LongcatMoe(nn.Module): ): super().__init__() self.hidden_size = hidden_size - self.zero_expert_num = config.zero_expert_num - self.zero_expert_type = config.zero_expert_type - self.routed_scaling_factor = config.routed_scaling_factor - self.enable_eplb = enable_eplb # Gate always runs at half / full precision for now. self.rounter_params_dtype = params_dtype if config.router_dtype == "float32": @@ -291,25 +287,27 @@ class LongcatMoe(nn.Module): self.router = LongcatRouter( config=config, - zero_expert_num=self.zero_expert_num, + zero_expert_num=config.zero_expert_num, rounter_params_dtype=self.rounter_params_dtype, prefix=f"{prefix}.gate", ) - self.experts = FusedMoE( + assert config.zero_expert_num is not None + assert config.zero_expert_type is not None + self.experts = ZeroExpertFusedMoE( + zero_expert_num=config.zero_expert_num, + zero_expert_type=config.zero_expert_type, + router=self.router, num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, reduce_results=True, params_dtype=params_dtype, - e_score_correction_bias=self.router.e_score_correction_bias, renormalize=False, quant_config=quant_config, prefix=f"{prefix}.experts", - zero_expert_num=self.zero_expert_num, - zero_expert_type=self.zero_expert_type, - enable_eplb=self.enable_eplb, + enable_eplb=enable_eplb, routed_scaling_factor=config.routed_scaling_factor, ) @@ -317,11 +315,34 @@ class LongcatMoe(nn.Module): num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.router(hidden_states.to(self.rounter_params_dtype)) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits + # Align to FusedMoE padded hidden size to avoid dim mismatch + padded_hidden = self.experts.hidden_size + if hidden_dim < padded_hidden: + hidden_states_padded = torch.nn.functional.pad( + hidden_states, + (0, padded_hidden - hidden_dim), + mode="constant", + value=0.0, + ) + else: + hidden_states_padded = hidden_states + + router_logits_full = self.router( + hidden_states_padded.to(self.rounter_params_dtype) ) + # ZeroExpertFusedMoE handles routing memoization and zero expert computation + # internally. Pass full router_logits (including zero experts) so that + # zero experts can be properly identified in routing. + final_hidden_states = self.experts( + hidden_states=hidden_states_padded, + router_logits=router_logits_full, # Full logits (includes zero experts) + ) + + # Crop back to original hidden dimension if padded earlier + if padded_hidden != hidden_dim: + final_hidden_states = final_hidden_states[..., :hidden_dim] + return final_hidden_states.view(num_tokens, hidden_dim) @@ -419,6 +440,7 @@ class FlashDecoderLayer(nn.Module): hidden_states = self.self_attn[0]( positions=positions, hidden_states=hidden_states, + llama_4_scaling=None, ) hidden_states, residual = self.post_attention_layernorm[0]( @@ -438,6 +460,7 @@ class FlashDecoderLayer(nn.Module): hidden_states = self.self_attn[1]( positions=positions, hidden_states=hidden_states, + llama_4_scaling=None, ) hidden_states, residual = self.post_attention_layernorm[1]( hidden_states, residual