[Bugfix] Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding (#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
(cherry picked from commit ef2c4f778d)
Elvir Crnčević authored 2026-03-19 01:28:37 +01:00, committed by khluu
parent 6edd43de3c
commit 89138b21cc
2 changed files with 58 additions and 1 deletion
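Why this matters: under CUDA graph replay the decode batch is padded to a fixed size, and the padded slots have seq_lens == 0, so the attention kernel never writes their output rows. If the output buffer comes from new_empty(), those rows hold uninitialized memory, and any downstream per-tensor reduction can pick up NaN. A minimal sketch of the failure mode, with hypothetical shapes and NaN standing in for uninitialized memory (nothing below is from the patch itself):

import torch

# Hypothetical shapes: 3 real decode requests padded to a graph batch of 8.
num_real, padded_batch, heads, d = 3, 8, 16, 512

# new_empty() leaves padding rows uninitialized; simulate that with NaN.
out = torch.empty(padded_batch, heads, d)
out[:num_real] = 1.0            # rows the attention kernel actually writes
out[num_real:] = float("nan")   # stand-in for garbage in the padding slots

# A per-tensor reduction (e.g. an amax used to derive an FP8 scale) sees
# the whole buffer, so NaN in the padding poisons the real tokens too.
assert torch.isnan(out.abs().amax())

# Zero-initializing the buffer once keeps the padding slots inert.
out = torch.zeros(padded_batch, heads, d)
out[:num_real] = 1.0
assert torch.isfinite(out.abs().amax())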


@@ -162,6 +162,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def _sm100_cutlass_mla_decode(
         self,
         q_nope: torch.Tensor,
@@ -218,7 +223,15 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
         )
-        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        if (
+            self._decode_out is None
+            or self._decode_out.shape[0] < B_q
+            or self._decode_out.dtype != dtype
+        ):
+            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        out = self._decode_out[:B_q]
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode

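The allocation pattern above generalizes. Here is a hedged sketch of the same lazily sized, reused zero-init buffer as a standalone helper (DecodeOutBuffer is a hypothetical name; the patch inlines this logic in each backend):

import torch

class DecodeOutBuffer:
    # Hypothetical helper distilling the patch's inlined logic.
    def __init__(self) -> None:
        self._buf: torch.Tensor | None = None

    def get(
        self, batch: int, heads: int, d: int,
        dtype: torch.dtype, device: torch.device,
    ) -> torch.Tensor:
        # Reallocate only when the request outgrows the cached buffer or
        # the dtype changes; the zero-init memset runs once, not per call.
        if (
            self._buf is None
            or self._buf.shape[0] < batch
            or self._buf.dtype != dtype
        ):
            self._buf = torch.zeros(batch, heads, d, dtype=dtype, device=device)
        return self._buf[:batch]

buf = DecodeOutBuffer()
out = buf.get(8, 16, 512, torch.bfloat16, torch.device("cpu"))

Because CUDA graph replays use fixed shapes, the slow path runs once during capture or warmup; every replay afterwards reuses the same zeroed storage with no extra memset kernel.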

@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
     AttentionLayer,
     AttentionType,
     MultipleOf,
+    is_quantized_kv_cache,
 )
 
 from vllm.v1.attention.backends.utils import KVCacheLayoutType
@@ -151,6 +152,11 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -181,6 +187,37 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         if self.bmm2_scale is None:
             self.bmm2_scale = layer._v_scale_float
 
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        # q is 4D: (batch, q_len_per_req, num_heads, head_dim)
+        # FlashInfer has a bug where out= validation hardcodes 3D shape
+        # (batch, num_heads, kv_lora_rank), but the kernel writes 4D
+        # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
+        # So we can only pass out= for single-token decode (q_len == 1).
+        # For q_len > 1, we zero padding slots after the kernel returns.
+        # TODO: upstream fix to FlashInfer
+        B, q_len_per_req = q.shape[0], q.shape[1]
+        out_kwargs: dict[str, torch.Tensor] = {}
+        if q_len_per_req == 1:
+            dtype = (
+                torch.bfloat16
+                if is_quantized_kv_cache(self.kv_cache_dtype)
+                else q.dtype
+            )
+            if (
+                self._decode_out is None
+                or self._decode_out.shape[0] < B
+                or self._decode_out.dtype != dtype
+            ):
+                self._decode_out = torch.zeros(
+                    B,
+                    q.shape[2],
+                    self.kv_lora_rank,
+                    dtype=dtype,
+                    device=q.device,
+                )
+            out_kwargs["out"] = self._decode_out[:B]
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -193,8 +230,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             max_seq_len=attn_metadata.max_seq_len,
             bmm1_scale=self.bmm1_scale,
             bmm2_scale=self.bmm2_scale,
+            **out_kwargs,
         )
 
+        # For q_len > 1, we can't pass out=, so zero the padding slots instead.
+        if not out_kwargs:
+            num_real = attn_metadata.num_decodes
+            if num_real < o.shape[0]:
+                o[num_real:] = 0
+
         # Flatten the output for consistent shape
         o = o.view(-1, o.shape[-2], o.shape[-1])
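
For the q_len_per_req > 1 path, where out= cannot be passed, a minimal sketch of the post-hoc zeroing and flatten (hypothetical values, NaN standing in for uninitialized memory):

import torch

# Hypothetical: 3 real decode requests, graph-padded batch of 8, q_len 2.
num_decodes, padded_batch, q_len, heads, d = 3, 8, 2, 16, 512

# The kernel allocates its own 4D output; padding rows may be garbage.
o = torch.full((padded_batch, q_len, heads, d), float("nan"))
o[:num_decodes] = 1.0           # rows the kernel actually computed

# Mirror the patch's workaround: zero the padded tail after the call.
if num_decodes < o.shape[0]:
    o[num_decodes:] = 0

# Flatten to (batch * q_len, num_heads, kv_lora_rank) as the patch does.
o = o.view(-1, o.shape[-2], o.shape[-1])
assert torch.isfinite(o).all()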