[Bugfix] Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding (#37442)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
(cherry picked from commit ef2c4f778d)
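
Background on the failure mode (an illustrative note, not part of the diff): under CUDA graph capture the decode batch is padded to the captured graph size, and padding slots (seq_lens=0) are never written by the attention kernel. If the output buffer is allocated uninitialized, those slots can hold NaN, and any downstream per-tensor reduction (for example an amax used to compute a quantization scale) then becomes NaN as well. A minimal PyTorch sketch with made-up shapes:

    import torch

    # Padded decode batch of 4; only the first 2 slots are real requests.
    out = torch.empty(4, 128, 512)
    out[2:] = float("nan")               # stand-in for uninitialized padding slots
    out[:2] = torch.randn(2, 128, 512)   # the kernel only writes the real slots

    print(out.abs().amax())              # tensor(nan) -> a per-tensor scale goes NaN

    # Zero-initializing the buffer keeps padding slots finite, so the same
    # reduction over the full (padded) tensor stays well defined.
    out_zeroed = torch.zeros(4, 128, 512)
    out_zeroed[:2] = out[:2]
    print(out_zeroed.abs().amax())       # finite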
@@ -162,6 +162,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def _sm100_cutlass_mla_decode(
         self,
         q_nope: torch.Tensor,
@@ -218,7 +223,15 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
         )
-        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        if (
+            self._decode_out is None
+            or self._decode_out.shape[0] < B_q
+            or self._decode_out.dtype != dtype
+        ):
+            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        out = self._decode_out[:B_q]
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode
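
The allocation pattern above, pulled out as a standalone sketch (a hypothetical helper, not vLLM's API): the buffer is zeroed only when it is (re)allocated, that is, when it is missing, too small for the current batch, or of the wrong dtype; steady-state calls, including CUDA graph replays, reuse the same storage and pay no extra memset kernel.

    import torch

    class LazyZeroBuffer:
        """Grow-only, zero-initialized output buffer (illustrative only)."""

        def __init__(self) -> None:
            self._buf: torch.Tensor | None = None

        def get(
            self, batch: int, *dims: int, dtype: torch.dtype, device: str = "cpu"
        ) -> torch.Tensor:
            # Reallocate (and zero) only when necessary; otherwise hand back a
            # view of the existing storage so no memset runs on this step.
            if (
                self._buf is None
                or self._buf.shape[0] < batch
                or self._buf.dtype != dtype
            ):
                self._buf = torch.zeros(batch, *dims, dtype=dtype, device=device)
            return self._buf[:batch]

    # First call allocates and zeroes; a later call with an equal or smaller
    # batch (and the same dtype) reuses the same storage.
    buf = LazyZeroBuffer()
    a = buf.get(8, 128, 512, dtype=torch.float32)
    b = buf.get(4, 128, 512, dtype=torch.float32)
    assert a.data_ptr() == b.data_ptr()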
@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
     AttentionLayer,
     AttentionType,
     MultipleOf,
+    is_quantized_kv_cache,
 )
 from vllm.v1.attention.backends.utils import KVCacheLayoutType
 
@@ -151,6 +152,11 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -181,6 +187,37 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         if self.bmm2_scale is None:
             self.bmm2_scale = layer._v_scale_float
 
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        # q is 4D: (batch, q_len_per_req, num_heads, head_dim)
+        # FlashInfer has a bug where out= validation hardcodes 3D shape
+        # (batch, num_heads, kv_lora_rank), but the kernel writes 4D
+        # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
+        # So we can only pass out= for single-token decode (q_len == 1).
+        # For q_len > 1, we zero padding slots after the kernel returns.
+        # TODO: upstream fix to FlashInfer
+        B, q_len_per_req = q.shape[0], q.shape[1]
+        out_kwargs: dict[str, torch.Tensor] = {}
+        if q_len_per_req == 1:
+            dtype = (
+                torch.bfloat16
+                if is_quantized_kv_cache(self.kv_cache_dtype)
+                else q.dtype
+            )
+            if (
+                self._decode_out is None
+                or self._decode_out.shape[0] < B
+                or self._decode_out.dtype != dtype
+            ):
+                self._decode_out = torch.zeros(
+                    B,
+                    q.shape[2],
+                    self.kv_lora_rank,
+                    dtype=dtype,
+                    device=q.device,
+                )
+            out_kwargs["out"] = self._decode_out[:B]
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -193,8 +230,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             max_seq_len=attn_metadata.max_seq_len,
             bmm1_scale=self.bmm1_scale,
             bmm2_scale=self.bmm2_scale,
+            **out_kwargs,
         )
 
+        # For q_len > 1, we can't pass out= so we work around by zeroing padding slots
+        if not out_kwargs:
+            num_real = attn_metadata.num_decodes
+            if num_real < o.shape[0]:
+                o[num_real:] = 0
+
         # Flatten the output for consistent shape
         o = o.view(-1, o.shape[-2], o.shape[-1])
 
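
For the q_len_per_req > 1 path, where out= cannot be passed, the effect of the post-kernel workaround can be shown in isolation (shapes and the NaN injection are made up for illustration; this is not FlashInfer's API):

    import torch

    num_real, padded_batch = 5, 8                # num_real stands in for attn_metadata.num_decodes
    o = torch.randn(padded_batch, 2, 16, 512)    # (batch, q_len, num_heads, kv_lora_rank)
    o[num_real:] = float("nan")                  # pretend padding slots came back as NaN

    # The workaround: zero every slot past the real decode count before the
    # output is flattened and consumed by downstream per-tensor reductions.
    if num_real < o.shape[0]:
        o[num_real:] = 0

    o = o.view(-1, o.shape[-2], o.shape[-1])     # flatten (batch, q_len), as in the diff
    assert torch.isfinite(o).all()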