From 89138b21cc246ae944c741d5c399c148e2b770ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Elvir=20Crn=C4=8Devi=C4=87?=
Date: Thu, 19 Mar 2026 01:28:37 +0100
Subject: [PATCH] [Bugfix] Zero-init MLA attention output buffers to prevent
 NaN from CUDA graph padding (#37442)

Signed-off-by: Elvir Crncevic
Signed-off-by: Matthew Bonanni
Co-authored-by: Matthew Bonanni
(cherry picked from commit ef2c4f778df5aa07a44e663330e2dfdc16927d2a)
---
 vllm/v1/attention/backends/mla/cutlass_mla.py | 15 ++++++-
 .../attention/backends/mla/flashinfer_mla.py  | 44 +++++++++++++++++++
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index 19faf3c93..8fee72a1e 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -162,6 +162,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def _sm100_cutlass_mla_decode(
         self,
         q_nope: torch.Tensor,
@@ -218,7 +223,15 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
         )
-        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        if (
+            self._decode_out is None
+            or self._decode_out.shape[0] < B_q
+            or self._decode_out.dtype != dtype
+        ):
+            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        out = self._decode_out[:B_q]
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index ec8f4e640..0df182873 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -21,6 +21,7 @@
 from vllm.v1.attention.backend import (
     AttentionLayer,
     AttentionType,
     MultipleOf,
+    is_quantized_kv_cache,
 )
 from vllm.v1.attention.backends.utils import KVCacheLayoutType
@@ -151,6 +152,11 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -181,6 +187,37 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         if self.bmm2_scale is None:
             self.bmm2_scale = layer._v_scale_float
 
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        # q is 4D: (batch, q_len_per_req, num_heads, head_dim)
+        # FlashInfer has a bug where out= validation hardcodes 3D shape
+        # (batch, num_heads, kv_lora_rank), but the kernel writes 4D
+        # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
+        # So we can only pass out= for single-token decode (q_len == 1).
+        # For q_len > 1, we zero padding slots after the kernel returns.
+        # TODO: upstream fix to FlashInfer
+        B, q_len_per_req = q.shape[0], q.shape[1]
+        out_kwargs: dict[str, torch.Tensor] = {}
+        if q_len_per_req == 1:
+            dtype = (
+                torch.bfloat16
+                if is_quantized_kv_cache(self.kv_cache_dtype)
+                else q.dtype
+            )
+            if (
+                self._decode_out is None
+                or self._decode_out.shape[0] < B
+                or self._decode_out.dtype != dtype
+            ):
+                self._decode_out = torch.zeros(
+                    B,
+                    q.shape[2],
+                    self.kv_lora_rank,
+                    dtype=dtype,
+                    device=q.device,
+                )
+            out_kwargs["out"] = self._decode_out[:B]
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -193,8 +230,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             max_seq_len=attn_metadata.max_seq_len,
             bmm1_scale=self.bmm1_scale,
             bmm2_scale=self.bmm2_scale,
+            **out_kwargs,
         )
 
+        # For q_len > 1, we can't pass out= so we work around by zeroing padding slots
+        if not out_kwargs:
+            num_real = attn_metadata.num_decodes
+            if num_real < o.shape[0]:
+                o[num_real:] = 0
+
         # Flatten the output for consistent shape
         o = o.view(-1, o.shape[-2], o.shape[-1])
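Note (illustration only, not part of the applied diff): below is a minimal
standalone sketch of the buffer-reuse pattern both backends adopt in this
patch. The class name DecodeOutputBuffer and its get() helper are hypothetical;
the real code keeps the buffer in self._decode_out on each attention impl. The
point it shows is that the buffer is allocated with zeros once, regrown only
when the batch size or dtype changes, and sliced per call, so CUDA graph
replays avoid a per-step memset and rows for padded requests (seq_lens=0),
which the kernel never writes, stay zero instead of feeding NaN into
downstream per-tensor reductions.

import torch


class DecodeOutputBuffer:
    """Lazily sized, zero-initialized output buffer reused across decode calls."""

    def __init__(self) -> None:
        self._buf: torch.Tensor | None = None

    def get(
        self,
        batch: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        # Reallocate only when the batch grows or the dtype changes; zeros()
        # guarantees padding rows hold 0.0 rather than uninitialized memory.
        if (
            self._buf is None
            or self._buf.shape[0] < batch
            or self._buf.dtype != dtype
        ):
            self._buf = torch.zeros(
                batch, num_heads, head_dim, dtype=dtype, device=device
            )
        # Slice to the current (possibly padded) batch; slots the kernel does
        # not write keep their zero contents.
        return self._buf[:batch]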