[Feat][Bugfix] Enable additional dimension for Flashinfer MLA and fix routing dtype (#36931)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
This commit is contained in:
Dimitrios Bariamis
2026-03-14 00:41:16 +01:00
committed by GitHub
parent 6d53efd2a5
commit 367cf5cd3e
2 changed files with 18 additions and 5 deletions

View File

@@ -47,7 +47,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
from vllm.model_executor.layers.fused_moe import (
GateLinear,
RoutingMethodType,
SharedFusedMoE,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -333,8 +337,12 @@ class DeepseekV2MoE(nn.Module):
# NOTE(rob): this is a hack until we finish off the PR for
# merging TRTLLM kernels into the MK framework. Then we can
# query the MonolithicMK for the expected router logits.
# NOTE(dbari): Use BF16 if routing is not Deepseek, e.g. Mistral Large 3
self.gate.set_out_dtype(
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
torch.float32
if self.experts.quant_method.is_monolithic
and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
else torch.bfloat16
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -1197,6 +1205,11 @@ class DeepseekV2Model(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
if input_ids is None:
raise ValueError(
"Either input_ids or inputs_embeds must be provided "
"to DeepseekV2Model.forward"
)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else: