[MoE Refactor][5/N] Isolate zero expert to LongCatFlash (#28891)

Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: Dongjie Zou <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
Author: baonudesifeizhai
Date: 2025-12-20 13:22:04 -05:00
Committed by: GitHub
Parent: 560ae9638c
Commit: 54c8924384
19 changed files with 264 additions and 109 deletions
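LongCat-Flash routes each token both to regular FFN experts and to "zero" (identity) experts that perform no computation: a token assigned to a zero expert simply contributes its own hidden state, scaled by the routing weight. This PR moves that special case out of the generic FusedMoE path and behind a LongCatFlash-specific ZeroExpertFusedMoE layer. A minimal sketch of the behavior being isolated (names and the dense loop are illustrative, not vLLM's fused implementation):

import torch

def moe_with_zero_experts(hidden_states, topk_ids, topk_weights,
                          expert_ffns, num_real_experts):
    # Illustrative only: expert ids >= num_real_experts are zero experts.
    out = torch.zeros_like(hidden_states)
    for k in range(topk_ids.shape[1]):
        ids = topk_ids[:, k]
        w = topk_weights[:, k].unsqueeze(-1)
        zero = ids >= num_real_experts
        # Zero expert: identity contribution, no FFN compute.
        out[zero] += w[zero] * hidden_states[zero]
        # Regular experts: ordinary FFNs (a dense loop here for clarity).
        for e in range(num_real_experts):
            m = ids == e
            if m.any():
                out[m] += w[m] * expert_ffns[e](hidden_states[m])
    return out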

@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -179,7 +179,7 @@ class FlashConfig(PretrainedConfig):
self.intermediate_size = (
self.ffn_hidden_size
if hasattr(self, "ffn_hidden_size")
else self.intermediate_size
else intermediate_size
)
if hasattr(self, "moe_intermediate_size"):
self.moe_intermediate_size = self.moe_intermediate_size
@@ -280,10 +280,6 @@ class LongcatMoe(nn.Module):
):
super().__init__()
self.hidden_size = hidden_size
self.zero_expert_num = config.zero_expert_num
self.zero_expert_type = config.zero_expert_type
self.routed_scaling_factor = config.routed_scaling_factor
self.enable_eplb = enable_eplb
# Gate always runs at half / full precision for now.
self.rounter_params_dtype = params_dtype
if config.router_dtype == "float32":
@@ -291,25 +287,27 @@ class LongcatMoe(nn.Module):
self.router = LongcatRouter(
config=config,
zero_expert_num=self.zero_expert_num,
zero_expert_num=config.zero_expert_num,
rounter_params_dtype=self.rounter_params_dtype,
prefix=f"{prefix}.gate",
)
self.experts = FusedMoE(
assert config.zero_expert_num is not None
assert config.zero_expert_type is not None
self.experts = ZeroExpertFusedMoE(
zero_expert_num=config.zero_expert_num,
zero_expert_type=config.zero_expert_type,
router=self.router,
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
params_dtype=params_dtype,
e_score_correction_bias=self.router.e_score_correction_bias,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
zero_expert_num=self.zero_expert_num,
zero_expert_type=self.zero_expert_type,
enable_eplb=self.enable_eplb,
enable_eplb=enable_eplb,
routed_scaling_factor=config.routed_scaling_factor,
)
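The layer no longer keeps zero-expert state on the model; the router and the zero-expert config are handed straight to ZeroExpertFusedMoE, which resolves them internally. Conceptually, the router emits logits for real and zero experts together, and expert ids falling in the zero range need no FFN dispatch; a rough sketch of that routing step (hypothetical helper, assuming zero experts occupy the trailing router slots):

import torch

def route_with_zero_experts(router_logits, top_k, num_real_experts):
    # Top-k over the full logits so zero experts compete with real ones;
    # slots whose id lands in the zero range skip the fused-expert kernel.
    probs = torch.softmax(router_logits.float(), dim=-1)
    weights, ids = torch.topk(probs, top_k, dim=-1)
    is_zero = ids >= num_real_experts
    return weights, ids, is_zero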
@@ -317,11 +315,34 @@ class LongcatMoe(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.router(hidden_states.to(self.rounter_params_dtype))
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
# Align to FusedMoE padded hidden size to avoid dim mismatch
padded_hidden = self.experts.hidden_size
if hidden_dim < padded_hidden:
hidden_states_padded = torch.nn.functional.pad(
hidden_states,
(0, padded_hidden - hidden_dim),
mode="constant",
value=0.0,
)
else:
hidden_states_padded = hidden_states
router_logits_full = self.router(
hidden_states_padded.to(self.rounter_params_dtype)
)
# ZeroExpertFusedMoE handles routing memoization and zero expert computation
# internally. Pass full router_logits (including zero experts) so that
# zero experts can be properly identified in routing.
final_hidden_states = self.experts(
hidden_states=hidden_states_padded,
router_logits=router_logits_full, # Full logits (includes zero experts)
)
# Crop back to original hidden dimension if padded earlier
if padded_hidden != hidden_dim:
final_hidden_states = final_hidden_states[..., :hidden_dim]
return final_hidden_states.view(num_tokens, hidden_dim)
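The pad-then-crop in the new forward exists because the wrapped FusedMoE layer may report a hidden_size larger than the model's hidden dimension (padding for kernel alignment is the assumption here), so the activations are padded before dispatch and sliced back afterwards. The pattern in isolation:

import torch
import torch.nn.functional as F

def run_padded(hidden_states, padded_hidden, moe_fn):
    # Pad the hidden dim up to the expert layer's width, run, then crop back.
    hidden_dim = hidden_states.shape[-1]
    if hidden_dim < padded_hidden:
        hidden_states = F.pad(hidden_states, (0, padded_hidden - hidden_dim))
    return moe_fn(hidden_states)[..., :hidden_dim]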
@@ -419,6 +440,7 @@ class FlashDecoderLayer(nn.Module):
hidden_states = self.self_attn[0](
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=None,
)
hidden_states, residual = self.post_attention_layernorm[0](
@@ -438,6 +460,7 @@ class FlashDecoderLayer(nn.Module):
hidden_states = self.self_attn[1](
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=None,
)
hidden_states, residual = self.post_attention_layernorm[1](
hidden_states, residual