[ROCm][Perf] Enable shuffle kv cache layout and assembly paged attention kernel for AiterFlashAttentionBackend (#29887)

Signed-off-by: ganyi <ygan@amd.com>
Author: Pleaplusone
Date: 2026-01-15 23:29:53 +08:00
Committed by: GitHub
Parent: 361dfdc9d8
Commit: 130d6c9514
3 changed files with 309 additions and 82 deletions

File: vllm/_aiter_ops.py

@@ -833,6 +833,7 @@ class rocm_aiter_ops:
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
@@ -859,6 +860,7 @@ class rocm_aiter_ops:
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
@@ -906,6 +908,11 @@ class rocm_aiter_ops:
def is_mha_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_shuffle_kv_cache_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
@classmethod
@if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool:

File: vllm/envs.py

@@ -128,6 +128,7 @@ if TYPE_CHECKING:
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -1018,6 +1019,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
),
# Whether to use the shuffled kv cache layout
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
),
# Custom quick allreduce kernel for MI3* cards
# Choice of quantization level: FP, INT8, INT6, INT4 or NONE
# Recommended for large models to get allreduce
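Note: as a minimal sketch of opting in from Python (the two env-var names are the real vLLM flags touched here, but the launch flow and model name are illustrative assumptions; the variables must be set before vllm reads its envs module):

import os
# Illustrative opt-in; set before importing vllm so envs picks these up.
os.environ["VLLM_ROCM_USE_AITER"] = "1"                # master AITER switch
os.environ["VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT"] = "1"  # flag added in this commit
from vllm import LLM  # hypothetical usage; any ROCm-supported model
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")    # example model only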

File: vllm/v1/attention/backends/rocm_aiter_fa.py

@@ -7,6 +7,7 @@ from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
@@ -30,7 +31,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
_PARTITION_SIZE_ROCM = 256
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm():
import aiter
@@ -52,7 +52,7 @@ if current_platform.is_rocm():
cu_seqlens_kv_ptr, # [num_batches + 1]
token_to_batch_ptr, # [max_cum_tokens]
seq_start_ptr, # [num_batches]
- k_scale_ptr,
+ k_scale_ptr, # [1] / [num_blocks, num_kv_heads, page_size]
v_scale_ptr,
num_heads,
head_size,
@@ -64,13 +64,15 @@ if current_platform.is_rocm():
BLOCK_SIZE: tl.constexpr,
):
token_id = tl.program_id(0)
+ head_id = tl.program_id(1)
col_offsets = tl.arange(0, BLOCK_SIZE)
- if DEQUANT:
- k_scale = tl.load(k_scale_ptr)
- v_scale = tl.load(v_scale_ptr)
- key_ptr_offset = key_ptr + token_id * head_size * num_heads
- value_ptr_offset = value_ptr + token_id * head_size * num_heads
+ key_ptr_offset = (
+ key_ptr + token_id * head_size * num_heads + head_id * head_size
+ )
+ value_ptr_offset = (
+ value_ptr + token_id * head_size * num_heads + head_id * head_size
+ )
batch_idx = tl.load(token_to_batch_ptr + token_id)
batch_start = tl.load(seq_start_ptr + batch_idx)
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
@@ -89,24 +91,54 @@ if current_platform.is_rocm():
key_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ slot_id * num_heads * head_size
+ head_id * head_size
)
value_cache_ptr_offset = (
value_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ slot_id * num_heads * head_size
+ head_id * head_size
)
- for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
- mask = (col_offsets + i) < head_size * num_heads
- k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
- v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
+ k_reg = tl.load(key_cache_ptr_offset + col_offsets)
+ v_reg = tl.load(value_cache_ptr_offset + col_offsets)
if DEQUANT:
+ k_scale = tl.load(k_scale_ptr)
+ v_scale = tl.load(v_scale_ptr)
k_dtype = k_reg.dtype
v_dtype = v_reg.dtype
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
- tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
- tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
+ tl.store(key_ptr_offset + col_offsets, k_reg)
+ tl.store(value_ptr_offset + col_offsets, v_reg)
elif CACHE_FORMAT == "SHUFFLE":
# For a kv cache laid out as
# K: [num_blocks, num_head, head_dim // x, page_size, x]
# V: [num_blocks, num_head, page_size // x, head_dim, x]
key_cache_ptr_offset = (
key_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ head_id * head_size * PAGE_SIZE
+ slot_id * x
)
value_cache_ptr_offset = (
value_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ head_id * head_size * PAGE_SIZE
+ (slot_id // x) * head_size * x
+ slot_id % x
)
k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x
v_reg_offset = col_offsets * x
k_reg = tl.load(key_cache_ptr_offset + k_reg_offset)
v_reg = tl.load(value_cache_ptr_offset + v_reg_offset)
if DEQUANT:
k_scale = 1.0
v_scale = 1.0
k_reg = k_reg.to(tl.float32) * k_scale
v_reg = v_reg.to(tl.float32) * v_scale
tl.store(key_ptr_offset + col_offsets, k_reg)
tl.store(value_ptr_offset + col_offsets, v_reg)
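Note: to make the SHUFFLE index arithmetic above concrete, here is a small self-contained sketch (tensor names and sizes are illustrative assumptions, not from the commit) that builds the shuffled K layout [num_heads, head_dim // x, page_size, x] for one block in plain PyTorch and checks the kernel's flat-offset formula against it:

import torch

dtype = torch.float16
page_size, num_heads, head_dim = 16, 4, 64
x = 16 // torch.empty(0, dtype=dtype).element_size()  # 8 for fp16, 16 for fp8

# Plain per-block K as stored by the NHD path: [page_size, num_heads, head_dim]
k_plain = torch.randn(page_size, num_heads, head_dim, dtype=dtype)

# Shuffled K per block: [num_heads, head_dim // x, page_size, x]
k_shuffled = (
    k_plain.permute(1, 2, 0)                          # -> [head, dim, slot]
    .reshape(num_heads, head_dim // x, x, page_size)  # split dim into (dim // x, dim % x)
    .permute(0, 1, 3, 2)                              # -> [head, dim // x, slot, dim % x]
    .contiguous()
)

flat = k_shuffled.reshape(-1)
h, d, s = 2, 37, 5                                    # arbitrary head / dim / slot
base = h * head_dim * page_size + s * x               # mirrors key_cache_ptr_offset
off = (d // x) * page_size * x + d % x                # mirrors k_reg_offset
assert flat[base + off] == k_plain[s, h, d]

Since x = 16 // element_size, each group of x consecutive dim elements for a slot is a contiguous 16-byte run, which is presumably the access pattern the assembly paged-attention kernel is tuned for.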
def cp_mha_gather_cache(
key_cache: torch.Tensor,
@@ -123,17 +155,14 @@ if current_platform.is_rocm():
kv_cache_layout: str,
total_tokens: int,
):
assert kv_cache_layout in ["v0", "NHD", "HND"], (
"kv_cache_layout only support v0, NHD, HND"
assert kv_cache_layout in ["NHD", "SHUFFLE"], (
"kv_cache_layout only support NHD, SHUFFLE"
)
head_dim = key.shape[2]
- x = 0
+ x = 16 // key_cache.element_size()
# assert dequant is True, "Currently, we only support "\
# "gather cache with dequant"
# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
assert kv_cache_layout == "NHD", (
"ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
)
assert head_dim == key_cache.shape[3], (
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
@@ -141,7 +170,7 @@ if current_platform.is_rocm():
page_size = key_cache.shape[1]
num_heads = key_cache.shape[2]
- grid = lambda meta: (total_tokens,)
+ grid = lambda meta: (total_tokens, num_heads)
cp_mha_gather_cache_kernel[grid](
key_cache,
value_cache,
@@ -163,6 +192,112 @@ if current_platform.is_rocm():
BLOCK_SIZE=head_dim,
)
@triton.jit
def reshape_and_cache_shuffle_kernel(
key_ptr, # [num_tokens, num_kv_heads, head_size]
value_ptr, # [num_tokens, num_kv_heads, head_size]
key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x]
slot_mapping_ptr, # [num_tokens]
k_scale_ptr, # [num_blocks, num_kv_heads, block_size]
v_scale_ptr, # [num_blocks, num_kv_heads, block_size]
x,
k_stride0,
v_stride0,
block_size,
head_size,
num_kv_heads,
BLOCK_SIZE: tl.constexpr,
QUANT: tl.constexpr,
IS_FNUZ: tl.constexpr,
):
tid = tl.program_id(0)
head_id = tl.program_id(1)
offset = tl.arange(0, BLOCK_SIZE)
src_offset_k = tid * k_stride0 + head_id * head_size
src_offset_v = tid * v_stride0 + head_id * head_size
slot_id = tl.load(slot_mapping_ptr + tid)
if slot_id < 0:
return
block_id = slot_id // block_size
block_offset = slot_id % block_size
dst_offset = (
block_id * num_kv_heads * head_size * block_size
+ head_id * head_size * block_size
)
dst_k_shuffle_offset = (
dst_offset + offset // x * block_size * x + block_offset * x + offset % x
)
dst_v_shuffle_offset = (
dst_offset
+ block_offset // x * head_size * x
+ offset * x
+ block_offset % x
)
k_val = tl.load(key_ptr + src_offset_k + offset)
v_val = tl.load(value_ptr + src_offset_v + offset)
if QUANT:
k_scale = 1.0
v_scale = 1.0
k_dtype = key_cache_ptr.type.element_ty
v_dtype = value_cache_ptr.type.element_ty
k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype)
v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype)
tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val)
tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val)
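Note: analogously to the K-side check above, a quick verification of dst_v_shuffle_offset against the shuffled V layout [num_heads, block_size // x, head_size, x] (same caveats: sizes and names are illustrative):

import torch

dtype = torch.float16
block_size, num_heads, head_size = 16, 4, 64
x = 16 // torch.empty(0, dtype=dtype).element_size()

v_plain = torch.randn(block_size, num_heads, head_size, dtype=dtype)
v_shuffled = (
    v_plain.permute(1, 0, 2)                           # -> [head, slot, dim]
    .reshape(num_heads, block_size // x, x, head_size) # split slot into (slot // x, slot % x)
    .permute(0, 1, 3, 2)                               # -> [head, slot // x, dim, slot % x]
    .contiguous()
)

flat = v_shuffled.reshape(-1)
h, d, s = 1, 11, 9
base = h * head_size * block_size                      # head term of dst_offset
off = (s // x) * head_size * x + d * x + s % x         # mirrors dst_v_shuffle_offset
assert flat[base + off] == v_plain[s, h, d]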
def reshape_and_cache_shuffle_triton(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scales: torch.Tensor,
v_scales: torch.Tensor,
):
num_tokens = slot_mapping.shape[0]
_, num_kv_heads, head_size = key.shape
num_blocks, block_size, _, _ = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
dtype=key_cache.dtype,
device="meta",
)
v_cache_template = torch.empty(
[num_blocks, num_kv_heads, block_size // x, head_size, x],
dtype=value_cache.dtype,
device="meta",
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
QUANT = False
if kv_cache_dtype.startswith("fp8"):
QUANT = True
grid = (
num_tokens,
num_kv_heads,
)
reshape_and_cache_shuffle_kernel[grid](
key,
value,
new_key_cache,
new_value_cache,
slot_mapping,
k_scales,
v_scales,
x,
key.stride(0),
value.stride(0),
block_size,
head_size,
num_kv_heads,
BLOCK_SIZE=head_size,
QUANT=QUANT,
IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
)
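Note: the meta-device tensors above serve purely as shape templates; view_as reinterprets the allocator's [num_blocks, block_size, num_kv_heads, head_size] buffer under the shuffled shape without copying. A minimal sketch of the idea (names and sizes are illustrative assumptions):

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
x = 16 // torch.empty(0, dtype=torch.float16).element_size()

key_cache = torch.zeros(
    num_blocks, block_size, num_kv_heads, head_size, dtype=torch.float16
)
template = torch.empty(
    [num_blocks, num_kv_heads, head_size // x, block_size, x],
    dtype=key_cache.dtype,
    device="meta",  # carries shape/dtype only, allocates no storage
)
shuffled = key_cache.view_as(template)  # same storage, new logical shape
assert shuffled.data_ptr() == key_cache.data_ptr()

The shuffled semantics come entirely from how the writer kernel scatters into this view; the view itself never moves data.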
logger = init_logger(__name__)
@@ -253,6 +388,11 @@ class AiterFlashAttentionMetadata:
common_prefix_len: int
total_tokens: int
# Only for the fp8 shuffle-layout kv cache do we allocate a kv_scale per
# layer, since we might integrate per-token quant for the kv cache later.
k_scale: dict[str, torch.Tensor] | None
v_scale: dict[str, torch.Tensor] | None
class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
@@ -303,6 +443,7 @@ class AiterFlashAttentionMetadataBuilder(
dtype=self.model_config.dtype,
device=device,
)
self.scale = torch.tensor([1.0], dtype=torch.float, device=self.device)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
@@ -325,7 +466,27 @@ class AiterFlashAttentionMetadataBuilder(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
)
# Allocate per-token scales for the fp8 kv cache when the shuffle layout is enabled
if (
rocm_aiter_ops.is_shuffle_kv_cache_enabled()
and self.scale.numel() == 1
and self.vllm_config.cache_config.cache_dtype.startswith("fp8")
):
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
first_layer_name = [k for k in layers][0]
kv_cache_shape = (
self.vllm_config.compilation_config.static_forward_context[
first_layer_name
]
.kv_cache[0]
.shape
)
num_blocks = kv_cache_shape[1]
self.scale = torch.ones(
[num_blocks, self.num_heads_kv, self.block_size],
dtype=torch.float32,
device=self.device,
)
(
num_decodes,
num_extends,
@@ -507,6 +668,8 @@ class AiterFlashAttentionMetadataBuilder(
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
total_tokens=self.total_tokens,
k_scale=self.scale,
v_scale=self.scale,
)
return attn_metadata
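Note: the scale tensor allocated above has shape [num_blocks, num_kv_heads, block_size], i.e. one scale per cached token per head. A hypothetical lookup for one slot would be (sketch, not commit code; names are local to this snippet):

# Map a flat slot id from slot_mapping into the per-token scale tensor.
block_size = 16
slot_id, head_id = 85, 2
block_id, block_offset = divmod(slot_id, block_size)
# scale = k_scale[block_id, head_id, block_offset]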
@@ -548,7 +711,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@@ -630,7 +792,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=swa_cu_seqlens,
token_to_batch=swa_token_to_batch,
seq_starts=swa_seq_starts,
- dequant=False,
+ dequant=self.kv_cache_dtype.startswith("fp8"),
kv_cache_layout="NHD",
total_tokens=swa_total_tokens,
)
@@ -668,8 +830,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
min_seqlen_q: int,
block_table: torch.Tensor,
slot_mapping: torch.Tensor,
- k_scale: float,
- v_scale: float,
+ k_scale: torch.Tensor,
+ v_scale: torch.Tensor,
):
if self.sliding_window[0] != -1:
self.extend_for_sliding_window(
@@ -725,8 +887,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
token_to_batch=token_to_batch[chunk_idx],
seq_starts=chunk_starts[chunk_idx],
- dequant=False,
- kv_cache_layout="NHD",
+ dequant=self.kv_cache_dtype.startswith("fp8"),
+ kv_cache_layout="SHUFFLE"
+ if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
+ else "NHD",
total_tokens=total_token_per_batch[chunk_idx],
)
@@ -823,6 +987,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
if (
self.kv_sharing_target_layer_name is None
and key is not None
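Note: moving the fp8 view ahead of the cache write matters because reshape_and_cache_shuffle_triton stores through key_cache_ptr.type.element_ty, so the cache must already carry the fp8 element type. The view itself is a zero-copy dtype reinterpretation; a tiny illustration (assuming a PyTorch build with float8_e4m3fnuz, the fnuz variant used on ROCm):

import torch

raw = torch.zeros(4, dtype=torch.uint8)       # allocator's byte buffer
as_fp8 = raw.view(torch.float8_e4m3fnuz)      # same bytes, fp8 element type
assert as_fp8.data_ptr() == raw.data_ptr() and as_fp8.numel() == raw.numel()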
@@ -835,7 +1002,21 @@ class AiterFlashAttentionImpl(AttentionImpl):
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
# When the shuffle layout is used, per-token quant scales may be
# computed inside reshape_and_cache_shuffle_triton, which can differ
# from vLLM's usual per-layer scales.
reshape_and_cache_shuffle_triton(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
attn_metadata.k_scale,
attn_metadata.v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
@@ -847,10 +1028,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
- if self.kv_cache_dtype.startswith("fp8"):
- key_cache = key_cache.view(current_platform.fp8_dtype())
- value_cache = value_cache.view(current_platform.fp8_dtype())
# decode:extend:prefill
query = query[:num_actual_tokens]
if key is not None:
@@ -902,6 +1079,11 @@ class AiterFlashAttentionImpl(AttentionImpl):
extend_keys = key[extend_tokens_slice]
extend_values = value[extend_tokens_slice]
extend_outputs = output[extend_tokens_slice]
k_scale = layer._k_scale
v_scale = layer._v_scale
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
k_scale = attn_metadata.k_scale
v_scale = attn_metadata.v_scale
self.extend_forward(
attn_metadata=attn_metadata,
query=extend_querys,
@@ -920,14 +1102,17 @@ class AiterFlashAttentionImpl(AttentionImpl):
slot_mapping=attn_metadata.slot_mapping[
num_decodes : num_decodes + num_extends
],
- k_scale=layer._k_scale,
- v_scale=layer._v_scale,
+ k_scale=k_scale,
+ v_scale=v_scale,
)
# calculate for decodes
if num_decodes > 0:
assert attn_metadata.decode_metadata is not None
if self.sliding_window[0] != -1:
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
"Sliding window with shuffle layout is not supported yet."
)
from aiter.ops.triton.unified_attention import (
unified_attention,
)
@@ -957,6 +1142,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
return
assert attn_metadata.decode_metadata is not None
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
dtype=key_cache.dtype,
device="meta",
)
v_cache_template = torch.empty(
[num_blocks, num_kv_heads, block_size // x, head_size, x],
dtype=value_cache.dtype,
device="meta",
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
aiter.pa_fwd_asm(
Q=query[:num_decode_tokens],
K=new_key_cache,
V=new_value_cache,
block_tables=attn_metadata.block_table[:num_decodes],
context_lens=attn_metadata.seq_lens[:num_decodes],
block_tables_stride0=attn_metadata.block_table[
:num_decodes
].stride(0),
K_QScale=attn_metadata.k_scale,
V_QScale=attn_metadata.v_scale,
out_=output[:num_decode_tokens],
)
else:
_, num_heads, head_size = query.shape
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
num_seqs = attn_metadata.seq_lens.shape[0]