diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 3f298f7a5..17d5ec4bc 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -1006,6 +1006,9 @@ class FusedMoEConfig:
     # The activation type.
     in_dtype: torch.dtype
 
+    # Defaults to in_dtype if not specified.
+    router_logits_dtype: torch.dtype | None = None
+
     max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
 
     has_bias: bool = False
@@ -1022,6 +1025,9 @@ class FusedMoEConfig:
 
         assert self.max_num_tokens > 0
 
+        if self.router_logits_dtype is None:
+            self.router_logits_dtype = self.in_dtype
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 374dffde5..09cefa74c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -314,6 +314,7 @@ class FusedMoE(CustomOp):
         renormalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
         enable_eplb: Whether to enable expert parallelism load balancer.
+        router_logits_dtype: Data type for router logits buffers.
     """
 
     def __init__(
@@ -348,6 +349,7 @@ class FusedMoE(CustomOp):
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
         routing_method_type: int | None = None,
+        router_logits_dtype: torch.dtype | None = None,
     ):
         super().__init__()
 
@@ -559,6 +561,7 @@ class FusedMoE(CustomOp):
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=moe_in_dtype,
+            router_logits_dtype=router_logits_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
@@ -1509,7 +1512,9 @@ class FusedMoE(CustomOp):
         )
 
         self.batched_router_logits = torch.zeros(
-            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+            logits_shape,
+            dtype=moe.router_logits_dtype,
+            device=torch.cuda.current_device(),
         )
 
     def select_experts(
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 6fb09be7c..a1115024a 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -197,6 +197,7 @@ class Glm4MoE(nn.Module):
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
+            router_logits_dtype=torch.float32,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
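
The patch above threads a new router_logits_dtype knob from FusedMoEConfig through FusedMoE so the preallocated batched_router_logits buffer no longer has to share the activation dtype; glm4_moe.py opts into torch.float32 so expert routing runs at full precision even when activations are bf16/fp16. The standalone sketch below illustrates the same defaulting pattern outside vLLM (MoEDtypeConfig is a hypothetical stand-in, not vLLM's actual class, and the buffer shape is illustrative):

# Standalone sketch of the dtype-defaulting pattern added in
# FusedMoEConfig.__post_init__ and its effect on the router-logits buffer.
# MoEDtypeConfig is a hypothetical stand-in for illustration only.
from dataclasses import dataclass

import torch


@dataclass
class MoEDtypeConfig:
    # The activation dtype, mirroring FusedMoEConfig.in_dtype.
    in_dtype: torch.dtype
    # Mirrors the new field: falls back to in_dtype when left unset.
    router_logits_dtype: torch.dtype | None = None

    def __post_init__(self) -> None:
        if self.router_logits_dtype is None:
            self.router_logits_dtype = self.in_dtype


# Default: router logits inherit the activation dtype (pre-patch behavior).
cfg = MoEDtypeConfig(in_dtype=torch.bfloat16)
assert cfg.router_logits_dtype == torch.bfloat16

# Opt-in: keep routing in float32 while activations stay bf16, as
# glm4_moe.py now does via router_logits_dtype=torch.float32.
cfg_fp32 = MoEDtypeConfig(in_dtype=torch.bfloat16,
                          router_logits_dtype=torch.float32)
logits = torch.zeros((8, 64), dtype=cfg_fp32.router_logits_dtype)
assert logits.dtype == torch.float32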