[Models] Add SharedFusedMoE support to Qwen3MoE (#32082)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
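For readers skimming the diff, here is a minimal, self-contained sketch of the pattern being added: a shared expert whose output is scaled by a per-token sigmoid gate and then summed with the routed-expert output. This is illustrative only and not part of the diff; the GatedSharedExpert class and its names are plain-PyTorch stand-ins, not the vLLM Qwen3MoeMLP / SharedFusedMoE API.

import torch
import torch.nn.functional as F
from torch import nn


class GatedSharedExpert(nn.Module):
    """Illustrative stand-in for a shared expert with a scalar output gate."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # gate/up projection fused into one linear, as in SiluAndMul-style MLPs
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # per-token scalar gate, analogous to shared_expert_gate (hidden_size -> 1)
        self.expert_gate = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        out = self.down_proj(F.silu(gate) * up)
        # scale the shared-expert output by the sigmoid of the gate logit
        return torch.sigmoid(self.expert_gate(x)) * out


# Conceptually, the sparse MoE block then combines the two paths as:
#   shared_out = shared_expert(hidden_states)            # gated as above
#   fused_out  = routed_experts(hidden_states, logits)   # fused MoE kernel
#   final      = shared_out + fused_out if shared_out is not None else fused_out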
@@ -29,6 +29,7 @@ from itertools import islice
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from vllm.attention.layer import Attention
@@ -42,7 +43,7 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -86,6 +87,7 @@ class Qwen3MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         reduce_results: bool = True,
+        expert_gate: torch.nn.Linear | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -109,12 +111,17 @@ class Qwen3MoeMLP(nn.Module):
                 f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
+        self.expert_gate = expert_gate
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+        out = self.act_fn(gate_up)
+        out, _ = self.down_proj(out)
+
+        if self.expert_gate is not None:
+            out = F.sigmoid(self.expert_gate(x)[0]) * out
+
+        return out
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -159,20 +166,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             self.physical_expert_start + self.n_local_physical_experts
         )
 
-        self.experts = FusedMoE(
-            num_experts=self.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=True,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-            enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts,
-            is_sequence_parallel=self.is_sequence_parallel,
-        )
-
         self.gate = ReplicatedLinear(
             config.hidden_size,
             config.num_experts,
@@ -181,6 +174,46 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=f"{prefix}.gate",
         )
 
+        shared_expert_intermediate_size = getattr(
+            config, "shared_expert_intermediate_size", 0
+        )
+        if shared_expert_intermediate_size > 0:
+            self.shared_expert_gate = ReplicatedLinear(
+                config.hidden_size,
+                1,
+                bias=False,
+                quant_config=None,
+                prefix=f"{prefix}.shared_expert_gate",
+            )
+            self.shared_expert = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                expert_gate=self.shared_expert_gate,
+                prefix=f"{prefix}.shared_expert",
+            )
+        else:
+            self.shared_expert_gate = None
+            self.shared_expert = None
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
+            gate=self.gate,
+            num_experts=self.n_routed_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel,
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert hidden_states.dim() <= 2, (
             "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
@@ -194,15 +227,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
+        shared_out, fused_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states = (
+            shared_out + fused_out if shared_out is not None else fused_out
+        )
 
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0
             )
             final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states
+            )
 
         # return to 1d if input is 1d
         return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
@@ -467,7 +507,7 @@ class Qwen3MoeModel(nn.Module):
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",