[Model] Modify the gate implementation of glm4_moe (#22832)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-08-14 20:28:50 +08:00
Committed by: GitHub
Parent: 829b9a62d0
Commit: 92ff41abea
2 changed files with 11 additions and 11 deletions

vllm/model_executor/models/glm4_moe.py

@@ -41,7 +41,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -118,14 +117,15 @@ class Glm4MoE(nn.Module):
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     params_dtype=torch.float32,
-                                     prefix=f"{prefix}.gate")
+        # NOTE: In the transformers implementation, the gate isn't an
+        # nn.Linear, so we cannot use ReplicatedLinear here.
+        # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260
+        self.gate = nn.Linear(config.hidden_size,
+                              config.n_routed_experts,
+                              bias=False,
+                              dtype=torch.float32)
         self.gate.e_score_correction_bias = nn.Parameter(
             torch.empty(config.n_routed_experts, dtype=torch.float32))
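
For context (an illustrative aside, not part of the diff): a minimal standalone
sketch of the new gate, using made-up sizes in place of the real config values.
Assigning e_score_correction_bias directly onto the nn.Linear instance registers
it under the gate's parameter namespace, so checkpoint tensors named
"...gate.weight" and "...gate.e_score_correction_bias" both map onto this module.

    import torch
    import torch.nn as nn

    hidden_size, n_routed_experts = 4096, 128  # illustrative, not the real config

    gate = nn.Linear(hidden_size, n_routed_experts, bias=False,
                     dtype=torch.float32)
    # Assigning an nn.Parameter as an attribute registers it on the module,
    # so it appears in the state dict as "e_score_correction_bias".
    gate.e_score_correction_bias = nn.Parameter(
        torch.empty(n_routed_experts, dtype=torch.float32))

    x = torch.randn(2, hidden_size)
    logits = gate(x.to(torch.float32))  # plain tensor, shape (2, n_routed_experts)
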
@@ -181,7 +181,7 @@ class Glm4MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+        router_logits = self.gate(hidden_states.to(dtype=torch.float32))
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits) * self.routed_scaling_factor
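
Why the call site drops the tuple unpacking: vLLM's linear layers, including
ReplicatedLinear, return an (output, output_bias) pair from forward(), while
torch.nn.Linear returns the output tensor directly. A schematic comparison
(TupleLinear below is a stand-in for illustration, not the actual vLLM class):

    import torch
    import torch.nn as nn

    class TupleLinear(nn.Linear):
        # Stand-in mimicking the (output, output_bias) return convention
        # of vLLM's linear layers.
        def forward(self, x: torch.Tensor):
            return super().forward(x), None

    old_gate = TupleLinear(8, 4, bias=False)
    new_gate = nn.Linear(8, 4, bias=False)

    x = torch.randn(1, 8)
    router_logits, _ = old_gate(x)  # old code: unpack the (output, bias) tuple
    router_logits = new_gate(x)     # new code: plain tensor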