[ROCm][perf] Shuffle KV cache to use paged_attention_common (#32914)

Signed-off-by: Samu Tamminen <stammine@amd.com>
Co-authored-by: Tuukka Sarvi <tuukka.sarvi@amd.com>
This commit is contained in:
Samu Tamminen
2026-04-01 06:30:19 +03:00
committed by GitHub
parent cb0b443274
commit c49497726b
2 changed files with 89 additions and 4 deletions

View File

@@ -2070,5 +2070,56 @@ class rocm_aiter_ops:
out_=out_,
)
@staticmethod
def paged_attention_common(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    tmp_out: torch.Tensor,
    max_logits: torch.Tensor,
    exp_sums: torch.Tensor,
    max_seq_len: int,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables_stride0: int,
    scale: float,
    K_QScale_hip: torch.Tensor,
    V_QScale_hip: torch.Tensor,
    K_QScale_asm: torch.Tensor,
    V_QScale_asm: torch.Tensor,
    out_: torch.Tensor,
    kv_cache_dtype: str,
):
    """
    Thin pass-through to ``aiter.paged_attention_common``.

    Deliberately NOT guarded by the @is_aiter_supported decorator, so
    that an explicit backend selection via attention_config still works
    even when VLLM_ROCM_USE_AITER=0.

    The aiter import happens lazily inside the call, so merely loading
    this module does not require aiter to be installed.
    """
    # Lazy import: only resolved when the op is actually invoked.
    from aiter import paged_attention_common as _aiter_pa_common

    # Forward every argument by keyword, unchanged.
    call_kwargs = dict(
        Q=Q,
        K=K,
        V=V,
        tmp_out=tmp_out,
        max_logits=max_logits,
        exp_sums=exp_sums,
        max_seq_len=max_seq_len,
        block_tables=block_tables,
        context_lens=context_lens,
        block_tables_stride0=block_tables_stride0,
        scale=scale,
        K_QScale_hip=K_QScale_hip,
        V_QScale_hip=V_QScale_hip,
        K_QScale_asm=K_QScale_asm,
        V_QScale_asm=V_QScale_asm,
        out_=out_,
        kv_cache_dtype=kv_cache_dtype,
    )
    return _aiter_pa_common(**call_kwargs)
# NOTE(review): one-time module-import side effect — presumably registers the
# rocm_aiter_ops custom ops with the dispatcher; confirm register_ops_once()
# is idempotent, since it runs on every import of this module.
rocm_aiter_ops.register_ops_once()

View File

@@ -1247,7 +1247,23 @@ class AiterFlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale.expand(descale_shape),
)
elif rocm_aiter_ops.is_shuffle_kv_cache_enabled():
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
_, num_heads, head_size = query.shape
num_seqs = attn_metadata.seq_lens.shape[0]
max_num_partitions = (
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
tmp_out = torch.empty(
(num_seqs, num_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=query.device,
)
exp_sums = torch.empty(
(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=query.device,
)
max_logits = torch.empty_like(exp_sums)
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
@@ -1261,18 +1277,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
rocm_aiter_ops.pa_fwd_asm(
k_qscale = (
layer._k_scale
if attn_metadata.k_scale is None
else attn_metadata.k_scale
)
v_qscale = (
layer._v_scale
if attn_metadata.v_scale is None
else attn_metadata.v_scale
)
rocm_aiter_ops.paged_attention_common(
Q=query[:num_decode_tokens],
K=new_key_cache,
V=new_value_cache,
tmp_out=tmp_out,
max_logits=max_logits,
exp_sums=exp_sums,
max_seq_len=attn_metadata.max_seq_len,
block_tables=attn_metadata.block_table[:num_decodes],
context_lens=attn_metadata.seq_lens[:num_decodes],
block_tables_stride0=attn_metadata.block_table[
:num_decodes
].stride(0),
K_QScale=attn_metadata.k_scale,
V_QScale=attn_metadata.v_scale,
scale=self.scale,
K_QScale_hip=k_qscale,
V_QScale_hip=v_qscale,
K_QScale_asm=k_qscale,
V_QScale_asm=v_qscale,
out_=output[:num_decode_tokens],
kv_cache_dtype=self.kv_cache_dtype,
)
else:
_, num_heads, head_size = query.shape