[Bugfix] Define router_logits_dtype for remaining MoE models (#33737)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -174,6 +174,7 @@ class MiMoV2MoE(nn.Module):
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
scoring_func="sigmoid",
|
||||
router_logits_dtype=self.gate_dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user