[ROCm][perf] Shuffle KV cache to use paged_attention_common (#32914)
Signed-off-by: Samu Tamminen <stammine@amd.com> Co-authored-by: Tuukka Sarvi <tuukka.sarvi@amd.com>
This commit is contained in:
@@ -2070,5 +2070,56 @@ class rocm_aiter_ops:
|
||||
out_=out_,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def paged_attention_common(
|
||||
Q: torch.Tensor,
|
||||
K: torch.Tensor,
|
||||
V: torch.Tensor,
|
||||
tmp_out: torch.Tensor,
|
||||
max_logits: torch.Tensor,
|
||||
exp_sums: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables_stride0: int,
|
||||
scale: float,
|
||||
K_QScale_hip: torch.Tensor,
|
||||
V_QScale_hip: torch.Tensor,
|
||||
K_QScale_asm: torch.Tensor,
|
||||
V_QScale_asm: torch.Tensor,
|
||||
out_: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
):
|
||||
"""
|
||||
Paged attention common function.
|
||||
|
||||
This function is NOT wrapped with @is_aiter_supported decorator
|
||||
to allow explicit backend selection via attention_config to work
|
||||
even when VLLM_ROCM_USE_AITER=0.
|
||||
|
||||
Note: This performs lazy import of aiter.paged_attention_common
|
||||
"""
|
||||
from aiter import paged_attention_common
|
||||
|
||||
return paged_attention_common(
|
||||
Q=Q,
|
||||
K=K,
|
||||
V=V,
|
||||
tmp_out=tmp_out,
|
||||
max_logits=max_logits,
|
||||
exp_sums=exp_sums,
|
||||
max_seq_len=max_seq_len,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
block_tables_stride0=block_tables_stride0,
|
||||
scale=scale,
|
||||
K_QScale_hip=K_QScale_hip,
|
||||
V_QScale_hip=V_QScale_hip,
|
||||
K_QScale_asm=K_QScale_asm,
|
||||
V_QScale_asm=V_QScale_asm,
|
||||
out_=out_,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
rocm_aiter_ops.register_ops_once()
|
||||
|
||||
@@ -1247,7 +1247,23 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
elif rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
_, num_heads, head_size = query.shape
|
||||
num_seqs = attn_metadata.seq_lens.shape[0]
|
||||
max_num_partitions = (
|
||||
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||
) // _PARTITION_SIZE_ROCM
|
||||
tmp_out = torch.empty(
|
||||
(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=query.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
|
||||
x = 16 // key_cache.element_size()
|
||||
k_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, head_size // x, block_size, x],
|
||||
@@ -1261,18 +1277,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
new_key_cache = key_cache.view_as(k_cache_template)
|
||||
new_value_cache = value_cache.view_as(v_cache_template)
|
||||
rocm_aiter_ops.pa_fwd_asm(
|
||||
k_qscale = (
|
||||
layer._k_scale
|
||||
if attn_metadata.k_scale is None
|
||||
else attn_metadata.k_scale
|
||||
)
|
||||
v_qscale = (
|
||||
layer._v_scale
|
||||
if attn_metadata.v_scale is None
|
||||
else attn_metadata.v_scale
|
||||
)
|
||||
rocm_aiter_ops.paged_attention_common(
|
||||
Q=query[:num_decode_tokens],
|
||||
K=new_key_cache,
|
||||
V=new_value_cache,
|
||||
tmp_out=tmp_out,
|
||||
max_logits=max_logits,
|
||||
exp_sums=exp_sums,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
block_tables=attn_metadata.block_table[:num_decodes],
|
||||
context_lens=attn_metadata.seq_lens[:num_decodes],
|
||||
block_tables_stride0=attn_metadata.block_table[
|
||||
:num_decodes
|
||||
].stride(0),
|
||||
K_QScale=attn_metadata.k_scale,
|
||||
V_QScale=attn_metadata.v_scale,
|
||||
scale=self.scale,
|
||||
K_QScale_hip=k_qscale,
|
||||
V_QScale_hip=v_qscale,
|
||||
K_QScale_asm=k_qscale,
|
||||
V_QScale_asm=v_qscale,
|
||||
out_=output[:num_decode_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
_, num_heads, head_size = query.shape
|
||||
|
||||
Reference in New Issue
Block a user