[Bugfix] Fix GLM-4 MoE router logits dtype for data parallel chunking (#31055)

Signed-off-by: ReinforcedKnowledge <reinforced.knowledge@gmail.com>
This commit is contained in:
Yakine Tahtah
2026-01-06 18:57:56 +01:00
committed by GitHub
parent 142c4d1738
commit 4e67a8f616
3 changed files with 13 additions and 1 deletions

View File

@@ -1006,6 +1006,9 @@ class FusedMoEConfig:
# The activation type.
in_dtype: torch.dtype
# Defaults to in_dtype if not specified.
router_logits_dtype: torch.dtype | None = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
@@ -1022,6 +1025,9 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0
if self.router_logits_dtype is None:
self.router_logits_dtype = self.in_dtype
@property
def tp_size(self):
return self.moe_parallel_config.tp_size

View File

@@ -314,6 +314,7 @@ class FusedMoE(CustomOp):
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers.
"""
def __init__(
@@ -348,6 +349,7 @@ class FusedMoE(CustomOp):
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
router_logits_dtype: torch.dtype | None = None,
):
super().__init__()
@@ -559,6 +561,7 @@ class FusedMoE(CustomOp):
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
router_logits_dtype=router_logits_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
@@ -1509,7 +1512,9 @@ class FusedMoE(CustomOp):
)
self.batched_router_logits = torch.zeros(
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
logits_shape,
dtype=moe.router_logits_dtype,
device=torch.cuda.current_device(),
)
def select_experts(

View File

@@ -197,6 +197,7 @@ class Glm4MoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: