Use paged_attention_v1 for sliding window decode in rocm_aiter_fa (#34378)
Signed-off-by: Martin Yuan <myuan@meta.com> Co-authored-by: Martin Yuan <myuan@meta.com>
This commit is contained in:
committed by
GitHub
parent
f120bd42d3
commit
9ea1f598ce
@@ -1075,35 +1075,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
|
||||
"Sliding window with shuffle layout is not supported yet."
|
||||
)
|
||||
from aiter.ops.triton.unified_attention import (
|
||||
unified_attention,
|
||||
)
|
||||
|
||||
descale_shape = (
|
||||
attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
|
||||
key_cache.shape[2],
|
||||
)
|
||||
unified_attention(
|
||||
q=query[:num_decode_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_decode_tokens],
|
||||
cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
|
||||
max_seqlen_q=1, # optimize this
|
||||
seqused_k=attn_metadata.seq_lens[:num_decodes],
|
||||
max_seqlen_k=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=attn_metadata.block_table[:num_decodes],
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None,
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
@@ -1172,6 +1143,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
layer._v_scale,
|
||||
None,
|
||||
_PARTITION_SIZE_ROCM,
|
||||
1,
|
||||
self.sliding_window[0] + 1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
||||
Reference in New Issue
Block a user