From 8edaf3857027c75382672ade255f7cb531f96844 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sat, 24 Jan 2026 15:36:31 +0800
Subject: [PATCH] [Models] Add `SharedFusedMoE` support to Qwen3MoE (#32082)

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen3_moe.py | 80 ++++++++++++++++++-------
 1 file changed, 60 insertions(+), 20 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 8e49ccea5..567c03193 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -29,6 +29,7 @@ from itertools import islice
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from vllm.attention.layer import Attention
@@ -42,7 +43,7 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -86,6 +87,7 @@ class Qwen3MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         reduce_results: bool = True,
+        expert_gate: torch.nn.Linear | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -109,12 +111,17 @@ class Qwen3MoeMLP(nn.Module):
                 f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
+        self.expert_gate = expert_gate
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+        out = self.act_fn(gate_up)
+        out, _ = self.down_proj(out)
+
+        if self.expert_gate is not None:
+            out = F.sigmoid(self.expert_gate(x)[0]) * out
+
+        return out
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -159,20 +166,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             self.physical_expert_start + self.n_local_physical_experts
         )
 
-        self.experts = FusedMoE(
-            num_experts=self.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=True,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-            enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts,
-            is_sequence_parallel=self.is_sequence_parallel,
-        )
-
         self.gate = ReplicatedLinear(
             config.hidden_size,
             config.num_experts,
             bias=False,
             quant_config=None,
@@ -181,6 +174,46 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=f"{prefix}.gate",
         )
 
+        shared_expert_intermediate_size = getattr(
+            config, "shared_expert_intermediate_size", 0
+        )
+        if shared_expert_intermediate_size > 0:
+            self.shared_expert_gate = ReplicatedLinear(
+                config.hidden_size,
+                1,
+                bias=False,
+                quant_config=None,
+                prefix=f"{prefix}.shared_expert_gate",
+            )
+            self.shared_expert = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                expert_gate=self.shared_expert_gate,
+                prefix=f"{prefix}.shared_expert",
+            )
+        else:
+            self.shared_expert_gate = None
+            self.shared_expert = None
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
+            gate=self.gate,
+            num_experts=self.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert hidden_states.dim() <= 2, (
             "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
         )
@@ -194,15 +227,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
+        shared_out, fused_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states = (
+            shared_out + fused_out if shared_out is not None else fused_out
+        )
 
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0
             )
             final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states
+            )
 
         # return to 1d if input is 1d
         return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
@@ -467,7 +507,7 @@ class Qwen3MoeModel(nn.Module):
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
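
Note (not part of the patch above): the hunks wire a per-token, sigmoid-gated shared expert alongside the routed experts via SharedFusedMoE. Below is a minimal, framework-free sketch of the combination the patch implements; every name in it (toy_shared_plus_routed, shared_mlp, shared_gate, routed_moe, router) is a hypothetical stand-in, not the vLLM API.

import torch

def toy_shared_plus_routed(x, shared_mlp, shared_gate, routed_moe, router):
    # x: (num_tokens, hidden_size); all callables are illustrative stand-ins.
    router_logits = router(x)                  # (num_tokens, n_experts)
    fused_out = routed_moe(x, router_logits)   # routed top-k expert output
    # Per-token sigmoid gate on the shared-expert output, mirroring
    # Qwen3MoeMLP.forward when expert_gate is set.
    shared_out = torch.sigmoid(shared_gate(x)) * shared_mlp(x)
    # Final combination, mirroring Qwen3MoeSparseMoeBlock.forward.
    return shared_out + fused_out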