diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py
index 6f5d7d766..9b3d9fb22 100644
--- a/vllm/model_executor/models/afmoe.py
+++ b/vllm/model_executor/models/afmoe.py
@@ -142,6 +142,7 @@ class AfmoeMoE(nn.Module):
             e_score_correction_bias=self.expert_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
+            router_logits_dtype=torch.float32,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py
index fc10f790e..7725dfa2a 100644
--- a/vllm/model_executor/models/bailing_moe.py
+++ b/vllm/model_executor/models/bailing_moe.py
@@ -300,6 +300,7 @@ class BailingMoE(nn.Module):
             num_expert_group=self.n_group,
             topk_group=self.topk_group,
             use_grouped_topk=self.use_grouped_topk,
+            router_logits_dtype=self.router_dtype,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/models/flex_olmo.py b/vllm/model_executor/models/flex_olmo.py
index 11d0949a7..a2e2adc2a 100644
--- a/vllm/model_executor/models/flex_olmo.py
+++ b/vllm/model_executor/models/flex_olmo.py
@@ -71,7 +71,6 @@ class FlexOlmoMoE(nn.Module):
             prefix=f"{prefix}.gate",
         )
 
-        # Gate always runs at half / full precision for now.
         self.experts = FusedMoE(
             num_experts=hf_config.num_experts,
             top_k=hf_config.num_experts_per_tok,
@@ -82,6 +81,7 @@ class FlexOlmoMoE(nn.Module):
             quant_config=None,
             tp_size=tp_size,
             prefix=f"{prefix}.experts",
+            router_logits_dtype=torch.float32,
         )
 
         self.top_k = hf_config.num_experts_per_tok
diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py
index f8b426df0..32408e7c3 100644
--- a/vllm/model_executor/models/longcat_flash.py
+++ b/vllm/model_executor/models/longcat_flash.py
@@ -236,9 +236,9 @@ class FlashMLP(nn.Module):
 class LongcatRouter(nn.Module):
     def __init__(
         self,
-        config,
-        zero_expert_num=0,
-        rounter_params_dtype=torch.bfloat16,
+        config: FlashConfig,
+        zero_expert_num: int,
+        rounter_params_dtype: torch.dtype,
         prefix: str = "",
     ):
         super().__init__()
@@ -309,6 +309,7 @@ class LongcatMoe(nn.Module):
             prefix=f"{prefix}.experts",
             enable_eplb=enable_eplb,
             routed_scaling_factor=config.routed_scaling_factor,
+            router_logits_dtype=self.rounter_params_dtype,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/models/mimo_v2_flash.py b/vllm/model_executor/models/mimo_v2_flash.py
index f7640746a..f74ce59ab 100644
--- a/vllm/model_executor/models/mimo_v2_flash.py
+++ b/vllm/model_executor/models/mimo_v2_flash.py
@@ -174,6 +174,7 @@ class MiMoV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             scoring_func="sigmoid",
+            router_logits_dtype=self.gate_dtype,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py
index f0d7b4a75..8019dbdbe 100644
--- a/vllm/model_executor/models/step3p5.py
+++ b/vllm/model_executor/models/step3p5.py
@@ -388,6 +388,7 @@ class FusedMoEBlock(nn.Module):
             routed_scaling_factor=config.moe_router_scaling_factor,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
+            router_logits_dtype=torch.float32,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
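All of the hunks above thread a new `router_logits_dtype` argument into `FusedMoE`, either hard-coded to `torch.float32` or forwarded from the model's existing gate/router dtype. Below is a minimal sketch of the idea, assuming the argument simply controls the precision at which the gating logits are softmaxed and top-k'd; the `SimpleRouter` class and its attributes are illustrative stand-ins, not vLLM's `FusedMoE` implementation.

```python
import torch
import torch.nn as nn


class SimpleRouter(nn.Module):
    """Illustrative gating module (not vLLM's FusedMoE): shows how a
    router_logits_dtype argument could be applied before softmax/top-k."""

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int,
        router_logits_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.top_k = top_k
        self.router_logits_dtype = router_logits_dtype

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Cast the gate logits to the configured dtype (e.g. float32) so the
        # routing softmax and top-k selection run at higher precision than
        # the bf16/fp16 activations.
        logits = self.gate(hidden_states).to(self.router_logits_dtype)
        weights = torch.softmax(logits, dim=-1)
        topk_weights, topk_ids = torch.topk(weights, self.top_k, dim=-1)
        return topk_weights, topk_ids


# Example: bf16 activations, fp32 routing weights.
router = SimpleRouter(hidden_size=64, num_experts=8, top_k=2).to(torch.bfloat16)
x = torch.randn(4, 64, dtype=torch.bfloat16)
topk_weights, topk_ids = router(x)  # topk_weights come out in float32
```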