[Bugfix] Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding (#37442)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
(cherry picked from commit ef2c4f778d)
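
Context for the fix, in miniature: under CUDA graph padding, batch slots past the real requests have seq_lens=0, so the decode kernel never writes those output rows. When the output buffer comes from an uninitialized allocation, whatever bytes sit in the padding rows (possibly NaN) leak into any per-tensor reduction taken over the full padded batch. The following is an illustrative sketch only, not code from this change; the shapes are made up and an explicit NaN fill stands in for uninitialized memory:

import torch

# Hypothetical sizes: 4 padded slots, only 2 real decode requests.
B_padded, num_real, heads, d = 4, 2, 8, 16

# Emulate an uninitialized output buffer whose padding rows happen to hold NaN.
out = torch.full((B_padded, heads, d), float("nan"))
out[:num_real] = torch.randn(num_real, heads, d)   # the kernel writes only real rows

print(out.abs().amax())        # nan -> contaminates a downstream per-tensor scale

# Zero-initializing the buffer keeps the unwritten padding rows inert.
out_safe = torch.zeros(B_padded, heads, d)
out_safe[:num_real] = torch.randn(num_real, heads, d)
print(out_safe.abs().amax())   # finite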
@@ -162,6 +162,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def _sm100_cutlass_mla_decode(
         self,
         q_nope: torch.Tensor,
@@ -218,7 +223,15 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
         )
-        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        if (
+            self._decode_out is None
+            or self._decode_out.shape[0] < B_q
+            or self._decode_out.dtype != dtype
+        ):
+            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        out = self._decode_out[:B_q]
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode
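
The hunk above replaces the per-call new_empty with a cached, zero-initialized buffer that is reallocated only when the batch grows or the dtype changes, so steady-state CUDA graph replays neither re-allocate nor re-zero. A self-contained sketch of that pattern with assumed names (DecodeOutputCache is not part of vLLM):

import torch

class DecodeOutputCache:
    # Illustrative lazily sized, zero-initialized output buffer.
    def __init__(self) -> None:
        self._buf: torch.Tensor | None = None

    def get(
        self, batch: int, heads: int, d: int,
        dtype: torch.dtype, device: torch.device,
    ) -> torch.Tensor:
        # Reallocate (and zero) only if the cached buffer is too small
        # or has the wrong dtype; otherwise reuse the existing storage.
        if (
            self._buf is None
            or self._buf.shape[0] < batch
            or self._buf.dtype != dtype
        ):
            self._buf = torch.zeros(batch, heads, d, dtype=dtype, device=device)
        return self._buf[:batch]

cache = DecodeOutputCache()
a = cache.get(4, 8, 16, torch.float32, torch.device("cpu"))
b = cache.get(2, 8, 16, torch.float32, torch.device("cpu"))
assert a.data_ptr() == b.data_ptr()   # second call reuses the same storage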
@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
     AttentionLayer,
     AttentionType,
     MultipleOf,
+    is_quantized_kv_cache,
 )
 from vllm.v1.attention.backends.utils import KVCacheLayoutType
 
@@ -151,6 +152,11 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
+        # Pre-allocated output buffer, lazily sized on first call.
+        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
+        # from contaminating downstream per-tensor reductions.
+        self._decode_out: torch.Tensor | None = None
+
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -181,6 +187,37 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         if self.bmm2_scale is None:
             self.bmm2_scale = layer._v_scale_float
 
+        # Reuse pre-allocated zero-init output buffer to avoid a memset
+        # kernel on every CUDA graph replay.
+        # q is 4D: (batch, q_len_per_req, num_heads, head_dim)
+        # FlashInfer has a bug where out= validation hardcodes 3D shape
+        # (batch, num_heads, kv_lora_rank), but the kernel writes 4D
+        # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
+        # So we can only pass out= for single-token decode (q_len == 1).
+        # For q_len > 1, we zero padding slots after the kernel returns.
+        # TODO: upstream fix to FlashInfer
+        B, q_len_per_req = q.shape[0], q.shape[1]
+        out_kwargs: dict[str, torch.Tensor] = {}
+        if q_len_per_req == 1:
+            dtype = (
+                torch.bfloat16
+                if is_quantized_kv_cache(self.kv_cache_dtype)
+                else q.dtype
+            )
+            if (
+                self._decode_out is None
+                or self._decode_out.shape[0] < B
+                or self._decode_out.dtype != dtype
+            ):
+                self._decode_out = torch.zeros(
+                    B,
+                    q.shape[2],
+                    self.kv_lora_rank,
+                    dtype=dtype,
+                    device=q.device,
+                )
+            out_kwargs["out"] = self._decode_out[:B]
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
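
Because the out= path is only safe for single-token decode, the hunk above builds the out= keyword argument conditionally; the next hunk splats it into the library call. A generic sketch of that idiom, with placeholder names unrelated to FlashInfer:

import torch

def fake_kernel(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    # Stand-in for a library call that accepts an optional out= buffer.
    result = x * 2
    if out is not None:
        out.copy_(result)
        return out
    return result

def dispatch(x: torch.Tensor, cached: torch.Tensor | None, single_token: bool) -> torch.Tensor:
    out_kwargs: dict[str, torch.Tensor] = {}
    if single_token and cached is not None:
        # Forward out= only on the path where the preconditions hold.
        out_kwargs["out"] = cached[: x.shape[0]]
    return fake_kernel(x, **out_kwargs)

x = torch.randn(2, 4)
buf = torch.zeros(8, 4)
y = dispatch(x, buf, single_token=True)
assert y.data_ptr() == buf.data_ptr()   # result landed in the reusable buffer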
@@ -193,8 +230,15 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             max_seq_len=attn_metadata.max_seq_len,
             bmm1_scale=self.bmm1_scale,
             bmm2_scale=self.bmm2_scale,
+            **out_kwargs,
         )
 
+        # For q_len > 1, we can't pass out= so we work around by zeroing padding slots
+        if not out_kwargs:
+            num_real = attn_metadata.num_decodes
+            if num_real < o.shape[0]:
+                o[num_real:] = 0
+
         # Flatten the output for consistent shape
         o = o.view(-1, o.shape[-2], o.shape[-1])
 
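
On the multi-token path, where out= cannot be forwarded, the hunk clears the padding slots after the kernel returns instead. A standalone sketch of that fallback (names and shapes are placeholders, not the vLLM code):

import torch

def zero_padding_rows(o: torch.Tensor, num_real: int) -> torch.Tensor:
    # Rows past the real request count were never written by the kernel;
    # clear them so later per-tensor reductions stay finite.
    if num_real < o.shape[0]:
        o[num_real:] = 0
    return o

o = torch.full((4, 2, 8, 16), float("nan"))   # emulate padded, unwritten rows
o[:2] = torch.randn(2, 2, 8, 16)              # kernel output for 2 real requests
zero_padding_rows(o, num_real=2)
assert torch.isfinite(o).all()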