[Bugfix][Model] Fix ernie45 MoE gate & bias dtype to float32 (#25936)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
Author: CSWYF3634076
Date: 2025-09-30 19:11:21 +08:00
Committed by: GitHub
Parent: 1ad3aca682
Commit: ef6e0e7132
2 changed files with 13 additions and 7 deletions


@@ -120,11 +120,12 @@ class Ernie4_5_MoeMoE(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.moe_num_experts,
                                      bias=False,
+                                     params_dtype=torch.float32,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
         self.gate.e_score_correction_bias = nn.Parameter(
-            torch.empty(config.moe_num_experts))
+            torch.empty(config.moe_num_experts, dtype=torch.float32))
         self.experts = FusedMoE(
             num_experts=config.moe_num_experts,
@@ -157,7 +158,7 @@ class Ernie4_5_MoeMoE(nn.Module):
         if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
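
The patch casts the routing path to float32 in three places: the gate weights (params_dtype), the expert-score correction bias, and the activations fed to the gate. Below is a minimal sketch (not from the commit; the tensor sizes are illustrative, and the standalone bias is a stand-in for e_score_correction_bias, which the model applies inside FusedMoE) showing the failure mode this guards against: in bfloat16, near-tied router logits can round to the same value and flip the top-k expert selection.

# Minimal sketch (assumed sizes; not from the commit): compare expert
# selection when router logits are computed in float32 vs. bfloat16.
import torch

torch.manual_seed(0)
hidden_size, num_experts, top_k = 1024, 64, 6  # illustrative values

hidden_states = torch.randn(1, hidden_size, dtype=torch.bfloat16)
gate_weight = torch.randn(num_experts, hidden_size, dtype=torch.float32)
# Stand-in for e_score_correction_bias; folded into the logits here for brevity.
bias = torch.randn(num_experts, dtype=torch.float32)

# Patched path: cast activations to float32 before the gate, mirroring
# self.gate(hidden_states.to(dtype=torch.float32)) with float32 gate weights.
logits_fp32 = hidden_states.to(torch.float32) @ gate_weight.t() + bias

# Unpatched path: gate math kept in bfloat16 end to end.
logits_bf16 = (hidden_states @ gate_weight.to(torch.bfloat16).t()
               + bias.to(torch.bfloat16))

top_fp32 = torch.topk(logits_fp32, top_k).indices.sort().values
top_bf16 = torch.topk(logits_bf16.to(torch.float32), top_k).indices.sort().values
print("top-k experts match:", torch.equal(top_fp32, top_bf16))

With a random gate the two selections often agree; the case the patch protects is the near-tie, where a sub-ulp bfloat16 rounding difference changes which experts receive tokens, so keeping the gate, its correction bias, and its input in float32 makes routing match the reference model.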