From 8de7c636cc02a8306441af868b9c1d0e6d64799f Mon Sep 17 00:00:00 2001 From: Kevin McKay Date: Fri, 20 Feb 2026 00:25:46 -0600 Subject: [PATCH] [Bugfix][Hardware][AMD] Fix ROCM_AITER_FA speculative decoding support (#32877) Signed-off-by: c0de128 Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> --- vllm/v1/attention/backends/rocm_aiter_fa.py | 37 +++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 5ff450829..141d57d90 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -1076,10 +1076,43 @@ class AiterFlashAttentionImpl(AttentionImpl): # calculate for decodes if num_decodes > 0: assert attn_metadata.decode_metadata is not None - if self.sliding_window[0] != -1: + decode_max_query_len = attn_metadata.decode_metadata.max_query_len + + # Use unified_attention for speculative decoding (multi-token) + # or when sliding window is enabled + if self.sliding_window[0] != -1 or decode_max_query_len > 1: assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), ( - "Sliding window with shuffle layout is not supported yet." + "Shuffle KV cache layout is not supported with sliding " + "window or speculative decoding (multi-token decode)." ) + from aiter.ops.triton.unified_attention import ( + unified_attention, + ) + + descale_shape = ( + attn_metadata.query_start_loc[:num_decodes].shape[0] - 1, + key_cache.shape[2], + ) + unified_attention( + q=query[:num_decode_tokens], + k=key_cache, + v=value_cache, + out=output[:num_decode_tokens], + cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes], + max_seqlen_q=decode_max_query_len, + seqused_k=attn_metadata.seq_lens[:num_decodes], + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table[:num_decodes], + softcap=self.logits_soft_cap, + q_descale=None, + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return if rocm_aiter_ops.is_shuffle_kv_cache_enabled(): num_blocks, block_size, num_kv_heads, head_size = key_cache.shape