[MoE] Fix output_shape calculation in Attention layer to handle 3D query inputs (#31596)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-01-02 09:46:23 -06:00
committed by GitHub
parent bd877162eb
commit 6ef770df7c
2 changed files with 17 additions and 2 deletions

View File

@@ -357,8 +357,11 @@ class Attention(nn.Module, AttentionLayerBase):
if self.use_output: if self.use_output:
if output_shape is None: if output_shape is None:
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
num_tokens = query.shape[0]
output_shape = torch.Size( output_shape = torch.Size(
(*query.shape[:-1], self.num_heads * self.head_size_v) (num_tokens, self.num_heads * self.head_size_v)
) )
output_shape = output_shape if output_shape is not None else query.shape output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device) output = torch.empty(output_shape, dtype=output_dtype, device=query.device)

View File

@@ -180,7 +180,19 @@ def get_fp8_moe_backend(
scope="local", scope="local",
) )
if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant: # Determine if we should use DeepGEMM (top-level enable switch)
# - If explicitly set by user, respect their choice
# - If not platform supports DeepGEMM, disable it
# This helps avoid warning messages on unsupported platforms.
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
if not is_deep_gemm_supported():
use_deep_gemm = False
logger.info_once(
"DeepGEMM is disabled because the platform does not support it.",
scope="local",
)
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if not has_deep_gemm(): if not has_deep_gemm():
logger.warning_once( logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local" "DeepGEMM backend requested but not available.", scope="local"