[Bugfix] Define router_logits_dtype for remaining MoE models (#33737)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-02-04 00:24:14 -05:00
committed by GitHub
parent 2647163674
commit eb5ed20743
6 changed files with 9 additions and 4 deletions

View File

@@ -300,6 +300,7 @@ class BailingMoE(nn.Module):
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: