Fix CUDA graph decode capture crash in AITER FlashAttention (#36042)

Signed-off-by: Martin Yuan <myuan@meta.com>
Co-authored-by: Martin Yuan <myuan@meta.com>
This commit is contained in:
Mengtao (Martin) Yuan
2026-03-06 18:12:07 -08:00
committed by GitHub
parent 7eb524e64c
commit 1a9718085c

View File

@@ -1152,11 +1152,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
decode_max_query_len = attn_metadata.decode_metadata.max_query_len
# Use unified_attention for speculative decoding (multi-token)
# or when sliding window is enabled
-        if self.sliding_window[0] != -1 or decode_max_query_len > 1:
+        if decode_max_query_len > 1:
             assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
-                "Shuffle KV cache layout is not supported with sliding "
-                "window or speculative decoding (multi-token decode)."
+                "Shuffle KV cache layout is not supported with "
+                "speculative decoding (multi-token decode)."
)
from aiter.ops.triton.unified_attention import (
unified_attention,