From 6ef770df7c3f0d135c2f3a594c461949113aae91 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Fri, 2 Jan 2026 09:46:23 -0600 Subject: [PATCH] [MoE] Fix output_shape calculation in Attention layer to handle 3D query inputs (#31596) Signed-off-by: Andreas Karatzas --- vllm/attention/layer.py | 5 ++++- vllm/model_executor/layers/quantization/fp8.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a88544c1c..a09666b65 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -357,8 +357,11 @@ class Attention(nn.Module, AttentionLayerBase): if self.use_output: if output_shape is None: + # Handle both 2D [num_tokens, hidden] and + # 3D [num_tokens, heads, head_dim] query + num_tokens = query.shape[0] output_shape = torch.Size( - (*query.shape[:-1], self.num_heads * self.head_size_v) + (num_tokens, self.num_heads * self.head_size_v) ) output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 08e1f4d44..3dcd9a84a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -180,7 +180,19 @@ def get_fp8_moe_backend( scope="local", ) - if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant: + # Determine if we should use DeepGEMM (top-level enable switch) + # - If explicitly set by user, respect their choice + # - If the platform does not support DeepGEMM, disable it + # This helps avoid warning messages on unsupported platforms. 
+ use_deep_gemm = envs.VLLM_USE_DEEP_GEMM + if not is_deep_gemm_supported(): + use_deep_gemm = False + logger.info_once( + "DeepGEMM is disabled because the platform does not support it.", + scope="local", + ) + + if use_deep_gemm and moe_use_deep_gemm and block_quant: if not has_deep_gemm(): logger.warning_once( "DeepGEMM backend requested but not available.", scope="local"