[MoE Refactor][5/N] Isolate zero expert to LongCatFlash (#28891)

Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: Dongjie Zou <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
This commit is contained in:
baonudesifeizhai
2025-12-20 13:22:04 -05:00
committed by GitHub
parent 560ae9638c
commit 54c8924384
19 changed files with 264 additions and 109 deletions

View File

@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
topk_weights, topk_ids, _ = fused_moe.select_experts(
topk_weights, topk_ids = fused_moe.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)

View File

@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
ZeroExpertFusedMoE,
)
from vllm.triton_utils import HAS_TRITON
_config: dict[str, Any] | None = None
@@ -54,6 +57,7 @@ __all__ = [
"FusedMoEPrepareAndFinalize",
"RoutingMethodType",
"SharedFusedMoE",
"ZeroExpertFusedMoE",
"activation_without_mul",
"override_config",
"get_config",

View File

@@ -92,7 +92,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -110,10 +110,4 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
expert_map=None if self.disable_expert_map else layer.expert_map,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
return result

View File

@@ -32,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
@@ -350,8 +349,6 @@ class FusedMoE(CustomOp):
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
zero_expert_num: int | None = 0,
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
@@ -409,8 +406,6 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts
self.logical_num_experts = num_experts
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping
@@ -1525,15 +1520,15 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids, zero_expert_result)
(tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
The weights, expert ids, and zero expert computation result.
(topk_weights, topk_ids)
(tuple[torch.Tensor, torch.Tensor]):
The weights and expert ids.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
@@ -1655,23 +1650,7 @@ class FusedMoE(CustomOp):
assert topk_ids.dtype == indices_type or indices_type is None
# Compute zero expert result if needed
if (
self.zero_expert_num is not None
and self.zero_expert_num > 0
and self.zero_expert_type is not None
and self.global_num_experts is not None
):
zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=self.global_num_experts,
zero_expert_type=self.zero_expert_type,
hidden_states=hidden_states,
)
else:
zero_expert_result = None
return topk_weights, topk_ids, zero_expert_result
return topk_weights, topk_ids
def must_reduce_shared_expert_outputs(self) -> bool:
"""
@@ -1736,14 +1715,7 @@ class FusedMoE(CustomOp):
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name
)
if self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(fused_output, tuple)
fused_output, zero_expert_result = fused_output
return (reduce_output(fused_output) + zero_expert_result)[
..., :og_hidden_states
]
else:
return reduce_output(fused_output)[..., :og_hidden_states]
return reduce_output(fused_output)[..., :og_hidden_states]
else:
if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1841,13 +1813,6 @@ class FusedMoE(CustomOp):
final_hidden_states,
)
if self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
assert self.shared_experts is None
final_hidden_states, zero_expert_result = final_hidden_states
if zero_expert_result is not None:
final_hidden_states += zero_expert_result
if not skip_result_store:
if self.shared_experts is None:
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
@@ -2030,9 +1995,6 @@ class FusedMoE(CustomOp):
shared_output,
final_hidden_states,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
@@ -2051,9 +2013,6 @@ class FusedMoE(CustomOp):
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return (combine_output(final_hidden_states), zero_expert_result)
else:
return combine_output(final_hidden_states)

View File

@@ -295,7 +295,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -336,13 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=layer.expert_map,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
return result
def forward_cpu(
self,

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import torch
from torch import nn
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
class ZeroExpertFusedMoE(FusedMoE):
    """
    A FusedMoE operation that also computes the results of zero experts.

    Zero experts perform identity operations (scaled pass-through) instead
    of full MLP computations.

    This class uses memoization to avoid redundant routing computation:
    routing is computed once and reused for both zero expert computation
    and the main FusedMoE forward pass.
    """

    def __init__(
        self,
        zero_expert_num: int,
        zero_expert_type: str,
        router: nn.Module,
        **kwargs,
    ):
        # ZeroExpertFusedMoE manages its own custom_routing_function for
        # memoization, so an externally supplied one would be silently ignored;
        # reject it up front.
        assert (
            "custom_routing_function" not in kwargs
            or kwargs.get("custom_routing_function") is None
        ), (
            "ZeroExpertFusedMoE does not support external custom_routing_function. "
            "It manages its own for routing memoization."
        )

        # Automatically slice router's e_score_correction_bias to only include
        # real experts (not zero_experts) for the base FusedMoE.
        # The full bias will be used temporarily in forward() for routing.
        if hasattr(router, "e_score_correction_bias") and "num_experts" in kwargs:
            num_real_experts = kwargs["num_experts"]
            router_bias = router.e_score_correction_bias
            user_bias = kwargs.get("e_score_correction_bias")
            # Use router's bias if:
            # 1. User didn't provide bias, or
            # 2. User provided full bias (same size as router)
            if user_bias is None or user_bias.shape[0] == router_bias.shape[0]:
                kwargs["e_score_correction_bias"] = router_bias[:num_real_experts]

        # FusedMoE no longer accepts zero_expert_num/zero_expert_type.
        # We handle zero experts ourselves in forward().
        super().__init__(**kwargs)

        # Store the actual zero_expert_num and zero_expert_type for our own use
        self._actual_zero_expert_num = zero_expert_num
        self._actual_zero_expert_type = zero_expert_type
        self._router = router  # Full router (includes zero experts)

        # Expose zero_expert_num and zero_expert_type as attributes for
        # compatibility with quantization methods that check these attributes.
        # Deliberately set to 0/None so the base FusedMoE path never tries to
        # handle zero experts itself.
        self.zero_expert_num = 0
        self.zero_expert_type = None

        # Memoization state for routing results; populated in forward() and
        # cleared again after super().forward() has consumed them.
        self._memoized_topk_weights: torch.Tensor | None = None
        self._memoized_topk_ids: torch.Tensor | None = None

        # Create custom_routing_function to reuse memoized routing results
        def custom_routing_function(hidden_states, gating_output, topk, renormalize):
            """Return memoized `topk_weights` and `topk_ids`."""
            if self._memoized_topk_weights is None or self._memoized_topk_ids is None:
                raise RuntimeError(
                    "ZeroExpertFusedMoE: routing results not memoized. "
                    "Call select_experts first to compute routing."
                )
            return self._memoized_topk_weights, self._memoized_topk_ids

        self.custom_routing_function = custom_routing_function

    @contextmanager
    def _temporarily_set_attrs(self, **attrs):
        """
        Temporarily set attributes using object.__setattr__ and restore them.

        This bypasses nn.Module.__setattr__ to avoid Dynamo tracing issues.
        When PyTorch Dynamo traces the forward pass, it cannot handle
        nn.Module.__setattr__ calls (which include parameter registration logic),
        resulting in "Unsupported" errors. Using object.__setattr__ directly
        sets the attribute without triggering nn.Module's custom __setattr__,
        allowing Dynamo to trace the code successfully.
        """
        # Snapshot current values before overriding so they can be restored
        # even if the wrapped code raises.
        originals = {key: getattr(self, key) for key in attrs}
        try:
            for key, value in attrs.items():
                object.__setattr__(self, key, value)
            yield
        finally:
            for key, value in originals.items():
                object.__setattr__(self, key, value)

    def _compute_zero_expert_result(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute zero expert results using pre-computed routing.

        Returns None when zero experts are not configured. topk_ids/topk_weights
        are cloned because the Triton kernel may modify them in place —
        NOTE(review): presumably to mask out zero-expert slots; confirm against
        zero_experts_compute_triton.
        """
        if (
            self._actual_zero_expert_num is None
            or self._actual_zero_expert_num <= 0
            or self._actual_zero_expert_type is None
        ):
            return None
        return zero_experts_compute_triton(
            expert_indices=topk_ids.clone(),
            expert_scales=topk_weights.clone(),
            num_experts=self.logical_num_experts,
            zero_expert_type=self._actual_zero_expert_type,
            hidden_states=hidden_states,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,  # Full logits including zero experts
    ) -> torch.Tensor:
        """
        Forward pass with zero expert support and routing memoization.

        Args:
            hidden_states: Input hidden states
            router_logits: Full router logits (including zero experts)

        Returns:
            Combined output from real experts and zero experts
        """
        # Prepare temporary attribute overrides for routing computation:
        # disable our memoizing custom_routing_function so select_experts
        # computes real routing, and swap in the router's full bias (which
        # covers zero experts too).
        temp_attrs = {
            "custom_routing_function": None,  # Disable for first routing
        }
        if self._router is not None:
            temp_attrs["e_score_correction_bias"] = self._router.e_score_correction_bias

        # Compute routing with temporary attributes.
        # Pass full router_logits (including zero experts) so that zero experts
        # can be properly identified in topk_ids.
        with self._temporarily_set_attrs(**temp_attrs):
            topk_weights, topk_ids = self.select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,  # Full logits (includes zero experts)
            )

        # Compute zero expert result if needed
        zero_expert_result = self._compute_zero_expert_result(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

        # Memoize routing results for reuse in super().forward()
        self._memoized_topk_weights = topk_weights
        self._memoized_topk_ids = topk_ids

        # Slice router_logits for real experts only
        router_logits_sliced = router_logits[..., : self.logical_num_experts]

        # Compute real expert results (will reuse memoized routing via
        # custom_routing_function).
        # zero_expert_num is already 0, so FusedMoE won't handle zero experts.
        fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits_sliced,
        )

        # Combine results.
        # Both zero_expert_result and fused_out are computed from the same
        # hidden_states, so they should be on the same device.
        if zero_expert_result is not None:
            fused_out = fused_out + zero_expert_result

        # Clear memoization after use so stale routing can never leak into a
        # later forward pass.
        self._memoized_topk_weights = None
        self._memoized_topk_ids = None

        return fused_out

View File

@@ -764,7 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -500,7 +500,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -574,7 +574,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=layer.e_score_correction_bias,
)
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1166,7 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1403,7 +1403,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1765,7 +1765,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
f"{layer.activation} not supported for Marlin MoE."
)
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1991,7 +1991,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -2607,7 +2607,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
)
assert self.moe_quant_config is not None
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -142,7 +142,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -1292,13 +1292,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
select_result = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
topk_weights, topk_ids, zero_expert_result = select_result
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts,
@@ -1353,13 +1351,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
return result
class Fp8OnlineMoEMethod(Fp8MoEMethod):

View File

@@ -639,7 +639,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -900,7 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -796,7 +796,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
# Expert selection
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -1599,7 +1599,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)

View File

@@ -370,7 +370,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -896,7 +896,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -990,7 +990,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -338,7 +338,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -530,7 +530,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
@@ -738,7 +738,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -359,7 +359,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, _ = layer.select_experts(
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)

View File

@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -179,7 +179,7 @@ class FlashConfig(PretrainedConfig):
self.intermediate_size = (
self.ffn_hidden_size
if hasattr(self, "ffn_hidden_size")
else self.intermediate_size
else intermediate_size
)
if hasattr(self, "moe_intermediate_size"):
self.moe_intermediate_size = self.moe_intermediate_size
@@ -280,10 +280,6 @@ class LongcatMoe(nn.Module):
):
super().__init__()
self.hidden_size = hidden_size
self.zero_expert_num = config.zero_expert_num
self.zero_expert_type = config.zero_expert_type
self.routed_scaling_factor = config.routed_scaling_factor
self.enable_eplb = enable_eplb
# Gate always runs at half / full precision for now.
self.rounter_params_dtype = params_dtype
if config.router_dtype == "float32":
@@ -291,25 +287,27 @@ class LongcatMoe(nn.Module):
self.router = LongcatRouter(
config=config,
zero_expert_num=self.zero_expert_num,
zero_expert_num=config.zero_expert_num,
rounter_params_dtype=self.rounter_params_dtype,
prefix=f"{prefix}.gate",
)
self.experts = FusedMoE(
assert config.zero_expert_num is not None
assert config.zero_expert_type is not None
self.experts = ZeroExpertFusedMoE(
zero_expert_num=config.zero_expert_num,
zero_expert_type=config.zero_expert_type,
router=self.router,
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
params_dtype=params_dtype,
e_score_correction_bias=self.router.e_score_correction_bias,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
zero_expert_num=self.zero_expert_num,
zero_expert_type=self.zero_expert_type,
enable_eplb=self.enable_eplb,
enable_eplb=enable_eplb,
routed_scaling_factor=config.routed_scaling_factor,
)
@@ -317,11 +315,34 @@ class LongcatMoe(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.router(hidden_states.to(self.rounter_params_dtype))
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
# Align to FusedMoE padded hidden size to avoid dim mismatch
padded_hidden = self.experts.hidden_size
if hidden_dim < padded_hidden:
hidden_states_padded = torch.nn.functional.pad(
hidden_states,
(0, padded_hidden - hidden_dim),
mode="constant",
value=0.0,
)
else:
hidden_states_padded = hidden_states
router_logits_full = self.router(
hidden_states_padded.to(self.rounter_params_dtype)
)
# ZeroExpertFusedMoE handles routing memoization and zero expert computation
# internally. Pass full router_logits (including zero experts) so that
# zero experts can be properly identified in routing.
final_hidden_states = self.experts(
hidden_states=hidden_states_padded,
router_logits=router_logits_full, # Full logits (includes zero experts)
)
# Crop back to original hidden dimension if padded earlier
if padded_hidden != hidden_dim:
final_hidden_states = final_hidden_states[..., :hidden_dim]
return final_hidden_states.view(num_tokens, hidden_dim)
@@ -419,6 +440,7 @@ class FlashDecoderLayer(nn.Module):
hidden_states = self.self_attn[0](
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=None,
)
hidden_states, residual = self.post_attention_layernorm[0](
@@ -438,6 +460,7 @@ class FlashDecoderLayer(nn.Module):
hidden_states = self.self_attn[1](
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=None,
)
hidden_states, residual = self.post_attention_layernorm[1](
hidden_states, residual