[MoE Refactor][5/N] Isolate zero expert to LongCatFlash (#28891)
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: Dongjie Zou <85092850+baonudesifeizhai@users.noreply.github.com>
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -179,7 +179,7 @@ class FlashConfig(PretrainedConfig):
         self.intermediate_size = (
             self.ffn_hidden_size
             if hasattr(self, "ffn_hidden_size")
-            else self.intermediate_size
+            else intermediate_size
         )
         if hasattr(self, "moe_intermediate_size"):
             self.moe_intermediate_size = self.moe_intermediate_size
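The hunk above swaps the fallback from the attribute `self.intermediate_size` to the plain `intermediate_size` name, which is presumably the constructor argument. A minimal sketch of why that matters, using illustrative class names rather than the real FlashConfig (whose base class may already set the attribute):

# Illustrative repro, not the real FlashConfig: reading self.intermediate_size
# before it has been assigned raises AttributeError, while falling back to the
# constructor argument always works.
class BrokenConfig:
    def __init__(self, intermediate_size=11008, **kwargs):
        self.__dict__.update(kwargs)  # may or may not inject ffn_hidden_size
        self.intermediate_size = (
            self.ffn_hidden_size
            if hasattr(self, "ffn_hidden_size")
            else self.intermediate_size  # AttributeError if never set
        )


class FixedConfig:
    def __init__(self, intermediate_size=11008, **kwargs):
        self.__dict__.update(kwargs)
        self.intermediate_size = (
            self.ffn_hidden_size
            if hasattr(self, "ffn_hidden_size")
            else intermediate_size  # fall back to the constructor argument
        )


if __name__ == "__main__":
    print(FixedConfig().intermediate_size)                       # 11008
    print(FixedConfig(ffn_hidden_size=6144).intermediate_size)   # 6144
    try:
        BrokenConfig()
    except AttributeError as exc:
        print("broken fallback:", exc)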
@@ -280,10 +280,6 @@ class LongcatMoe(nn.Module):
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        self.zero_expert_num = config.zero_expert_num
-        self.zero_expert_type = config.zero_expert_type
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.enable_eplb = enable_eplb
         # Gate always runs at half / full precision for now.
         self.rounter_params_dtype = params_dtype
         if config.router_dtype == "float32":
@@ -291,25 +287,27 @@ class LongcatMoe(nn.Module):

         self.router = LongcatRouter(
             config=config,
-            zero_expert_num=self.zero_expert_num,
+            zero_expert_num=config.zero_expert_num,
             rounter_params_dtype=self.rounter_params_dtype,
             prefix=f"{prefix}.gate",
         )

-        self.experts = FusedMoE(
+        assert config.zero_expert_num is not None
+        assert config.zero_expert_type is not None
+        self.experts = ZeroExpertFusedMoE(
+            zero_expert_num=config.zero_expert_num,
+            zero_expert_type=config.zero_expert_type,
+            router=self.router,
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=True,
             params_dtype=params_dtype,
             e_score_correction_bias=self.router.e_score_correction_bias,
             renormalize=False,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
-            zero_expert_num=self.zero_expert_num,
-            zero_expert_type=self.zero_expert_type,
-            enable_eplb=self.enable_eplb,
+            enable_eplb=enable_eplb,
             routed_scaling_factor=config.routed_scaling_factor,
         )

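The hunk above moves zero-expert handling out of the generic FusedMoE and into a dedicated ZeroExpertFusedMoE that owns the router. A hedged sketch of what a zero expert amounts to, assuming the "identity" zero_expert_type (this is an interpretation for illustration, not vLLM's ZeroExpertFusedMoE implementation; all names below are hypothetical):

# Sketch: expert ids >= num_real_experts carry no weights. A token routed to
# such an id contributes routing_weight * x (an identity mapping) instead of
# paying for an FFN forward pass.
import torch


def moe_with_zero_experts(x, topk_ids, topk_weights, expert_ffns, num_real_experts):
    """x: [T, H]; topk_ids / topk_weights: [T, K]; expert_ffns: list of callables."""
    out = torch.zeros_like(x)
    for k in range(topk_ids.shape[1]):
        ids, w = topk_ids[:, k], topk_weights[:, k : k + 1]
        for e in ids.unique().tolist():
            mask = ids == e
            if e < num_real_experts:
                out[mask] += w[mask] * expert_ffns[e](x[mask])  # real expert FFN
            else:
                out[mask] += w[mask] * x[mask]                  # zero expert: identity
    return out

Under this reading, the router still produces logits over real plus zero experts, which is why the forward pass below feeds the full logits to the experts module.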
@@ -317,11 +315,34 @@ class LongcatMoe(nn.Module):
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

-        router_logits = self.router(hidden_states.to(self.rounter_params_dtype))
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
+        # Align to FusedMoE padded hidden size to avoid dim mismatch
+        padded_hidden = self.experts.hidden_size
+        if hidden_dim < padded_hidden:
+            hidden_states_padded = torch.nn.functional.pad(
+                hidden_states,
+                (0, padded_hidden - hidden_dim),
+                mode="constant",
+                value=0.0,
+            )
+        else:
+            hidden_states_padded = hidden_states
+
+        router_logits_full = self.router(
+            hidden_states_padded.to(self.rounter_params_dtype)
+        )
+
+        # ZeroExpertFusedMoE handles routing memoization and zero expert computation
+        # internally. Pass full router_logits (including zero experts) so that
+        # zero experts can be properly identified in routing.
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states_padded,
+            router_logits=router_logits_full,  # Full logits (includes zero experts)
         )
+
+        # Crop back to original hidden dimension if padded earlier
+        if padded_hidden != hidden_dim:
+            final_hidden_states = final_hidden_states[..., :hidden_dim]

         return final_hidden_states.view(num_tokens, hidden_dim)
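The forward pass now right-pads the hidden dimension up to the experts' padded hidden size before routing and crops it back afterwards. A standalone illustration of that pad-then-crop round trip, with toy sizes standing in for hidden_dim and self.experts.hidden_size:

# Toy demonstration of the padding pattern used in the hunk above.
import torch
import torch.nn.functional as F

hidden_dim, padded_hidden = 6144, 6400
x = torch.randn(4, hidden_dim)

# Right-pad the last dimension with zeros so shapes match the experts' kernels.
x_padded = F.pad(x, (0, padded_hidden - hidden_dim), mode="constant", value=0.0)
assert x_padded.shape == (4, padded_hidden)

y_padded = x_padded  # stand-in for self.experts(...)

# Crop back to the original hidden dimension before returning.
y = y_padded[..., :hidden_dim]
assert y.shape == x.shape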
@@ -419,6 +440,7 @@ class FlashDecoderLayer(nn.Module):
         hidden_states = self.self_attn[0](
             positions=positions,
             hidden_states=hidden_states,
+            llama_4_scaling=None,
         )

         hidden_states, residual = self.post_attention_layernorm[0](
@@ -438,6 +460,7 @@ class FlashDecoderLayer(nn.Module):
         hidden_states = self.self_attn[1](
             positions=positions,
             hidden_states=hidden_states,
+            llama_4_scaling=None,
         )
         hidden_states, residual = self.post_attention_layernorm[1](
             hidden_states, residual