[Bugfix][Hardware][AMD] Fix ROCM_AITER_FA speculative decoding support (#32877)
Signed-off-by: c0de128 <kevin.mckay@outlook.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -1076,10 +1076,43 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
# calculate for decodes
|
# calculate for decodes
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
assert attn_metadata.decode_metadata is not None
|
assert attn_metadata.decode_metadata is not None
|
||||||
if self.sliding_window[0] != -1:
|
decode_max_query_len = attn_metadata.decode_metadata.max_query_len
|
||||||
|
|
||||||
|
# Use unified_attention for speculative decoding (multi-token)
|
||||||
|
# or when sliding window is enabled
|
||||||
|
if self.sliding_window[0] != -1 or decode_max_query_len > 1:
|
||||||
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
|
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
|
||||||
"Sliding window with shuffle layout is not supported yet."
|
"Shuffle KV cache layout is not supported with sliding "
|
||||||
|
"window or speculative decoding (multi-token decode)."
|
||||||
)
|
)
|
||||||
|
from aiter.ops.triton.unified_attention import (
|
||||||
|
unified_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
descale_shape = (
|
||||||
|
attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
|
||||||
|
key_cache.shape[2],
|
||||||
|
)
|
||||||
|
unified_attention(
|
||||||
|
q=query[:num_decode_tokens],
|
||||||
|
k=key_cache,
|
||||||
|
v=value_cache,
|
||||||
|
out=output[:num_decode_tokens],
|
||||||
|
cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
|
||||||
|
max_seqlen_q=decode_max_query_len,
|
||||||
|
seqused_k=attn_metadata.seq_lens[:num_decodes],
|
||||||
|
max_seqlen_k=attn_metadata.max_seq_len,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=True,
|
||||||
|
alibi_slopes=self.alibi_slopes,
|
||||||
|
window_size=self.sliding_window,
|
||||||
|
block_table=attn_metadata.block_table[:num_decodes],
|
||||||
|
softcap=self.logits_soft_cap,
|
||||||
|
q_descale=None,
|
||||||
|
k_descale=layer._k_scale.expand(descale_shape),
|
||||||
|
v_descale=layer._v_scale.expand(descale_shape),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||||
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
|
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
|
||||||
|
|||||||
Reference in New Issue
Block a user