From 367cf5cd3eb234e0e191a6d7883bc66f54f42f79 Mon Sep 17 00:00:00 2001
From: Dimitrios Bariamis
Date: Sat, 14 Mar 2026 00:41:16 +0100
Subject: [PATCH] [Feat][Bugfix] Enable additional dimension for Flashinfer MLA and fix routing dtype (#36931)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
---
 vllm/model_executor/models/deepseek_v2.py        | 17 +++++++++++++++--
 .../v1/attention/backends/mla/flashinfer_mla.py  |  6 +++---
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index a198f1a0b..f31e9ac3e 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -47,7 +47,11 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import (
+    GateLinear,
+    RoutingMethodType,
+    SharedFusedMoE,
+)
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -333,8 +337,12 @@ class DeepseekV2MoE(nn.Module):
         # NOTE(rob): this is a hack until we finish off the PR for
         # merging TRTLLM kernels into the MK framework. Then we can
         # query the MonolithicMK for the expected router logits.
+        # NOTE(dbari): Use BF16 if routing is not Deepseek, e.g. Mistral Large 3
         self.gate.set_out_dtype(
-            torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
+            torch.float32
+            if self.experts.quant_method.is_monolithic
+            and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
+            else torch.bfloat16
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -1197,6 +1205,11 @@ class DeepseekV2Model(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
+                if input_ids is None:
+                    raise ValueError(
+                        "Either input_ids or inputs_embeds must be provided "
+                        "to DeepseekV2Model.forward"
+                    )
                 hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 102d5706b..86852534a 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -75,16 +75,16 @@ class FlashInferMLABackend(MLACommonBackend):
         use_sparse: bool,
         device_capability: DeviceCapability,
     ) -> str | None:
-        # FlashInfer MLA kernel requires qk_nope_head_dim == 128
+        # FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
         from vllm.config import get_current_vllm_config
 
         vllm_config = get_current_vllm_config()
         if vllm_config.model_config is not None:
             hf_text_config = vllm_config.model_config.hf_text_config
             qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
-            if qk_nope_head_dim != 128:
+            if qk_nope_head_dim not in [64, 128]:
                 return (
-                    f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
+                    f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], "
                     f"but got {qk_nope_head_dim}"
                 )
         return None