[Bugfix] Fix GLM-4 MoE router logits dtype for data parallel chunking (#31055)

Signed-off-by: ReinforcedKnowledge <reinforced.knowledge@gmail.com>
This commit is contained in:
Yakine Tahtah
2026-01-06 18:57:56 +01:00
committed by GitHub
parent 142c4d1738
commit 4e67a8f616
3 changed files with 13 additions and 1 deletion

View File

@@ -197,6 +197,7 @@ class Glm4MoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: