Fix CUDA graph decode capture crash in AITER FlashAttention (#36042)
Signed-off-by: Martin Yuan <myuan@meta.com> Co-authored-by: Martin Yuan <myuan@meta.com>
This commit is contained in:
committed by
GitHub
parent
7eb524e64c
commit
1a9718085c
@@ -1152,11 +1152,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
         decode_max_query_len = attn_metadata.decode_metadata.max_query_len

         # Use unified_attention for speculative decoding (multi-token)
-        # or when sliding window is enabled
-        if self.sliding_window[0] != -1 or decode_max_query_len > 1:
+        if decode_max_query_len > 1:
             assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
-                "Shuffle KV cache layout is not supported with sliding "
-                "window or speculative decoding (multi-token decode)."
+                "Shuffle KV cache layout is not supported with "
+                "speculative decoding (multi-token decode)."
             )
             from aiter.ops.triton.unified_attention import (
                 unified_attention,
Reference in New Issue
Block a user