[Models] Add SharedFusedMoE support to Qwen3MoE (#32082)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-24 15:36:31 +08:00
committed by GitHub
parent 5c86a89805
commit 8edaf38570

View File

@@ -29,6 +29,7 @@ from itertools import islice
from typing import Any
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention.layer import Attention
@@ -42,7 +43,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -86,6 +87,7 @@ class Qwen3MoeMLP(nn.Module):
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
expert_gate: torch.nn.Linear | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -109,12 +111,17 @@ class Qwen3MoeMLP(nn.Module):
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
self.expert_gate = expert_gate
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
out = self.act_fn(gate_up)
out, _ = self.down_proj(out)
if self.expert_gate is not None:
out = F.sigmoid(self.expert_gate(x)[0]) * out
return out
class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -159,20 +166,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)
self.experts = FusedMoE(
num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=True,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
@@ -181,6 +174,46 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=f"{prefix}.gate",
)
shared_expert_intermediate_size = getattr(
config, "shared_expert_intermediate_size", 0
)
if shared_expert_intermediate_size > 0:
self.shared_expert_gate = ReplicatedLinear(
config.hidden_size,
1,
bias=False,
quant_config=None,
prefix=f"{prefix}.shared_expert_gate",
)
self.shared_expert = Qwen3MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=shared_expert_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
expert_gate=self.shared_expert_gate,
prefix=f"{prefix}.shared_expert",
)
else:
self.shared_expert_gate = None
self.shared_expert = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_expert,
gate=self.gate,
num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert hidden_states.dim() <= 2, (
"Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
@@ -194,15 +227,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
shared_out, fused_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
final_hidden_states = (
shared_out + fused_out if shared_out is not None else fused_out
)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states
)
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
@@ -467,7 +507,7 @@ class Qwen3MoeModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
return SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",