From 130d6c9514856cb5a152329f0382d60ff6e8d97e Mon Sep 17 00:00:00 2001
From: Pleaplusone
Date: Thu, 15 Jan 2026 23:29:53 +0800
Subject: [PATCH] [ROCm][Perf] Enable shuffle kv cache layout and assembly
 paged attention kernel for `AiterFlashAttentionBackend` (#29887)

Signed-off-by: ganyi
---
 vllm/_aiter_ops.py                          |   7 +
 vllm/envs.py                                |   5 +
 vllm/v1/attention/backends/rocm_aiter_fa.py | 379 +++++++++++++++-----
 3 files changed, 309 insertions(+), 82 deletions(-)

diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index b443f7735..c9ad8a5ae 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -833,6 +833,7 @@ class rocm_aiter_ops:
     _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
     _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
     _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
+    _SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
     _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
     # TODO: Consolidate under _LINEAR_ENABLED
     _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
@@ -859,6 +860,7 @@ class rocm_aiter_ops:
         cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
         cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
         cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
+        cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
         cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
         cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
         cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
@@ -906,6 +908,11 @@ class rocm_aiter_ops:
     def is_mha_enabled(cls) -> bool:
         return cls._AITER_ENABLED and cls._MHA_ENABLED
 
+    @classmethod
+    @if_aiter_supported
+    def is_shuffle_kv_cache_enabled(cls) -> bool:
+        return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
+
     @classmethod
     @if_aiter_supported
     def is_triton_unified_attn_enabled(cls) -> bool:
diff --git a/vllm/envs.py b/vllm/envs.py
index ca4bda46f..65bbd29f3 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
     VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
+    VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -1018,6 +1019,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
         os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
     ),
+    # Whether to use the shuffled kv cache layout
+    "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
+        os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
+    ),
     # Custom quick allreduce kernel for MI3* cards
     # Choice of quantization level: FP, INT8, INT6, INT4 or NONE
     # Recommended for large models to get allreduce
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index f384aaa46..3febbe57a 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -7,6 +7,7 @@ from typing import ClassVar
 
 import torch
 
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
@@ -30,7 +31,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 _PARTITION_SIZE_ROCM = 256
 _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
 
-
 if current_platform.is_rocm():
     import aiter
 
@@ -52,7 +52,7 @@ if current_platform.is_rocm():
         cu_seqlens_kv_ptr,  # [num_batches + 1]
         token_to_batch_ptr,  # [max_cum_tokens]
         seq_start_ptr,  # [num_batches]
-        k_scale_ptr,
+        k_scale_ptr,  # [1] / [num_blocks, num_kv_heads, page_size]
         v_scale_ptr,
         num_heads,
         head_size,
@@ -64,13 +64,15 @@ if current_platform.is_rocm():
         BLOCK_SIZE: tl.constexpr,
     ):
         token_id = tl.program_id(0)
+        head_id = tl.program_id(1)
         col_offsets = tl.arange(0, BLOCK_SIZE)
-        if DEQUANT:
-            k_scale = tl.load(k_scale_ptr)
-            v_scale = tl.load(v_scale_ptr)
-        key_ptr_offset = key_ptr + token_id * head_size * num_heads
-        value_ptr_offset = value_ptr + token_id * head_size * num_heads
+        key_ptr_offset = (
+            key_ptr + token_id * head_size * num_heads + head_id * head_size
+        )
+        value_ptr_offset = (
+            value_ptr + token_id * head_size * num_heads + head_id * head_size
+        )
         batch_idx = tl.load(token_to_batch_ptr + token_id)
         batch_start = tl.load(seq_start_ptr + batch_idx)
         token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
@@ -89,24 +91,54 @@ if current_platform.is_rocm():
             key_cache_ptr
             + block_id * num_heads * head_size * PAGE_SIZE
             + slot_id * num_heads * head_size
+            + head_id * head_size
         )
         value_cache_ptr_offset = (
             value_cache_ptr
             + block_id * num_heads * head_size * PAGE_SIZE
             + slot_id * num_heads * head_size
+            + head_id * head_size
         )
+        k_reg = tl.load(key_cache_ptr_offset + col_offsets)
+        v_reg = tl.load(value_cache_ptr_offset + col_offsets)
+        if DEQUANT:
+            k_scale = tl.load(k_scale_ptr)
+            v_scale = tl.load(v_scale_ptr)
+            k_dtype = k_reg.dtype
+            v_dtype = v_reg.dtype
+            k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
+            v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+        tl.store(key_ptr_offset + col_offsets, k_reg)
+        tl.store(value_ptr_offset + col_offsets, v_reg)
 
-        for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
-            mask = (col_offsets + i) < head_size * num_heads
-            k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
-            v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
-            if DEQUANT:
-                k_dtype = k_reg.dtype
-                v_dtype = v_reg.dtype
-                k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
-                v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
-            tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
-            tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
+        elif CACHE_FORMAT == "SHUFFLE":
+            # For the shuffled kv cache layout:
+            # K: [num_blocks, num_heads, head_dim // x, page_size, x]
+            # V: [num_blocks, num_heads, page_size // x, head_dim, x]
+            key_cache_ptr_offset = (
+                key_cache_ptr
+                + block_id * num_heads * head_size * PAGE_SIZE
+                + head_id * head_size * PAGE_SIZE
+                + slot_id * x
+            )
+            value_cache_ptr_offset = (
+                value_cache_ptr
+                + block_id * num_heads * head_size * PAGE_SIZE
+                + head_id * head_size * PAGE_SIZE
+                + (slot_id // x) * head_size * x
+                + slot_id % x
+            )
+            k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x
+            v_reg_offset = col_offsets * x
+            k_reg = tl.load(key_cache_ptr_offset + k_reg_offset)
+            v_reg = tl.load(value_cache_ptr_offset + v_reg_offset)
+            if DEQUANT:
+                k_scale = 1.0  # placeholder until per-token scales are wired up
+                v_scale = 1.0  # placeholder until per-token scales are wired up
+                k_reg = k_reg.to(tl.float32) * k_scale
+                v_reg = v_reg.to(tl.float32) * v_scale
+            tl.store(key_ptr_offset + col_offsets, k_reg)
+            tl.store(value_ptr_offset + col_offsets, v_reg)
 
     def cp_mha_gather_cache(
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         block_table: torch.Tensor,
         cu_seqlens_kv: torch.Tensor,
         token_to_batch: torch.Tensor,
         seq_starts: torch.Tensor,
         dequant: bool,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
         kv_cache_layout: str,
         total_tokens: int,
     ):
-        assert kv_cache_layout in ["v0", "NHD", "HND"], (
-            "kv_cache_layout only support v0, NHD, HND"
+        assert kv_cache_layout in ["NHD", "SHUFFLE"], (
+            "kv_cache_layout only supports NHD or SHUFFLE"
         )
         head_dim = key.shape[2]
-        x = 0
+        x = 16 // key_cache.element_size()
         # assert dequant is True, "Currently, we only support "\
         #     "gather cache with dequant"
         # For k cache layout: [num_blocks, page_size, num_heads, head_dim]
-        assert kv_cache_layout == "NHD", (
-            "ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
-        )
         assert head_dim == key_cache.shape[3], (
             "We assume your kv cache layout is [num_blocks, "
             "page_size, num_heads, head_dim], but got otherwise"
         )
         page_size = key_cache.shape[1]
         num_heads = key_cache.shape[2]
 
-        grid = lambda meta: (total_tokens,)
+        grid = lambda meta: (total_tokens, num_heads)
         cp_mha_gather_cache_kernel[grid](
             key_cache,
             value_cache,
@@ -163,6 +192,112 @@ if current_platform.is_rocm():
             BLOCK_SIZE=head_dim,
         )
 
+    @triton.jit
+    def reshape_and_cache_shuffle_kernel(
+        key_ptr,  # [num_tokens, num_kv_heads, head_size]
+        value_ptr,  # [num_tokens, num_kv_heads, head_size]
+        key_cache_ptr,  # [num_blocks, num_kv_heads, head_size // x, block_size, x]
+        value_cache_ptr,  # [num_blocks, num_kv_heads, block_size // x, head_size, x]
+        slot_mapping_ptr,  # [num_tokens]
+        k_scale_ptr,  # [num_blocks, num_kv_heads, block_size] (not read yet)
+        v_scale_ptr,  # [num_blocks, num_kv_heads, block_size] (not read yet)
+        x,
+        k_stride0,
+        v_stride0,
+        block_size,
+        head_size,
+        num_kv_heads,
+        BLOCK_SIZE: tl.constexpr,
+        QUANT: tl.constexpr,
+        IS_FNUZ: tl.constexpr,  # reserved for fnuz fp8 handling; unused for now
+    ):
+        tid = tl.program_id(0)
+        head_id = tl.program_id(1)
+        offset = tl.arange(0, BLOCK_SIZE)
+        src_offset_k = tid * k_stride0 + head_id * head_size
+        src_offset_v = tid * v_stride0 + head_id * head_size
+        slot_id = tl.load(slot_mapping_ptr + tid)
+        if slot_id < 0:
+            return
+        block_id = slot_id // block_size
+        block_offset = slot_id % block_size
+        dst_offset = (
+            block_id * num_kv_heads * head_size * block_size
+            + head_id * head_size * block_size
+        )
+        dst_k_shuffle_offset = (
+            dst_offset + offset // x * block_size * x + block_offset * x + offset % x
+        )
+        dst_v_shuffle_offset = (
+            dst_offset
+            + block_offset // x * head_size * x
+            + offset * x
+            + block_offset % x
+        )
+        k_val = tl.load(key_ptr + src_offset_k + offset)
+        v_val = tl.load(value_ptr + src_offset_v + offset)
+        if QUANT:
+            k_scale = 1.0  # placeholder until per-token scales are wired up
+            v_scale = 1.0  # placeholder until per-token scales are wired up
+            k_dtype = key_cache_ptr.type.element_ty
+            v_dtype = value_cache_ptr.type.element_ty
+            k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype)
+            v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype)
+        tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val)
+        tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val)
+
+    def reshape_and_cache_shuffle_triton(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scales: torch.Tensor,
+        v_scales: torch.Tensor,
+    ):
+        num_tokens = slot_mapping.shape[0]
+        _, num_kv_heads, head_size = key.shape
+        num_blocks, block_size, _, _ = key_cache.shape
+        x = 16 // key_cache.element_size()
+        k_cache_template = torch.empty(
+            [num_blocks, num_kv_heads, head_size // x, block_size, x],
+            dtype=key_cache.dtype,
+            device="meta",
+        )
+        v_cache_template = torch.empty(
+            [num_blocks, num_kv_heads, block_size // x, head_size, x],
+            dtype=value_cache.dtype,
+            device="meta",
+        )
+        new_key_cache = key_cache.view_as(k_cache_template)
+        new_value_cache = value_cache.view_as(v_cache_template)
+        QUANT = False
+        if kv_cache_dtype.startswith("fp8"):
+            QUANT = True
+        grid = (
+            num_tokens,
+            num_kv_heads,
+        )
+        reshape_and_cache_shuffle_kernel[grid](
+            key,
+            value,
+            new_key_cache,
+            new_value_cache,
+            slot_mapping,
+            k_scales,
+            v_scales,
+            x,
+            key.stride(0),
+            value.stride(0),
+            block_size,
+            head_size,
+            num_kv_heads,
+            BLOCK_SIZE=head_size,
+            QUANT=QUANT,
+            IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
+        )
+
 
 logger = init_logger(__name__)
 
@@ -253,6 +388,11 @@ class AiterFlashAttentionMetadata:
     common_prefix_len: int
     total_tokens: int
 
+    # Scales for the fp8 shuffled-layout kv cache. Full per-slot tensors are
+    # allocated since per-token quantization may be integrated in the future.
+    k_scale: torch.Tensor | None
+    v_scale: torch.Tensor | None
+
 
 class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
@@ -303,6 +443,7 @@ class AiterFlashAttentionMetadataBuilder(
             dtype=self.model_config.dtype,
             device=device,
         )
+        self.scale = torch.tensor([1.0], dtype=torch.float, device=self.device)
 
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
@@ -325,7 +466,27 @@ class AiterFlashAttentionMetadataBuilder(
             common_attn_metadata,
             decode_threshold=self.reorder_batch_threshold,
         )
-
+        # Lazily allocate per-slot scales for the fp8 shuffled kv cache layout
+        if (
+            rocm_aiter_ops.is_shuffle_kv_cache_enabled()
+            and self.scale.numel() == 1
+            and self.vllm_config.cache_config.cache_dtype.startswith("fp8")
+        ):
+            layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+            first_layer_name = next(iter(layers))
+            kv_cache_shape = (
+                self.vllm_config.compilation_config.static_forward_context[
+                    first_layer_name
+                ]
+                .kv_cache[0]
+                .shape
+            )
+            num_blocks = kv_cache_shape[1]
+            self.scale = torch.ones(
+                [num_blocks, self.num_heads_kv, self.block_size],
+                dtype=torch.float32,
+                device=self.device,
+            )
         (
             num_decodes,
             num_extends,
@@ -507,6 +668,8 @@ class AiterFlashAttentionMetadataBuilder(
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
+            k_scale=self.scale,
+            v_scale=self.scale,
         )
 
         return attn_metadata
@@ -548,7 +711,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
-
         return (2, num_blocks, block_size, num_kv_heads, head_size)
@@ -630,7 +792,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
             cu_seqlens_kv=swa_cu_seqlens,
             token_to_batch=swa_token_to_batch,
             seq_starts=swa_seq_starts,
-            dequant=False,
+            dequant=self.kv_cache_dtype.startswith("fp8"),
             kv_cache_layout="NHD",
             total_tokens=swa_total_tokens,
         )
@@ -668,8 +830,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
         min_seqlen_q: int,
         block_table: torch.Tensor,
         slot_mapping: torch.Tensor,
-        k_scale: float,
-        v_scale: float,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
     ):
         if self.sliding_window[0] != -1:
             self.extend_for_sliding_window(
@@ -725,8 +887,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
                 token_to_batch=token_to_batch[chunk_idx],
                 seq_starts=chunk_starts[chunk_idx],
-                dequant=False,
-                kv_cache_layout="NHD",
+                dequant=self.kv_cache_dtype.startswith("fp8"),
+                kv_cache_layout="SHUFFLE"
+                if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
+                else "NHD",
                 total_tokens=total_token_per_batch[chunk_idx],
             )
@@ -823,6 +987,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
         # key and value may be None in the case of cross attention. They are
         # calculated once based on the output from the encoder and then cached
         # in KV cache.
+        if self.kv_cache_dtype.startswith("fp8"):
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())
         if (
             self.kv_sharing_target_layer_name is None
             and key is not None
@@ -835,21 +1002,31 @@ class AiterFlashAttentionImpl(AttentionImpl):
             # key[:num_actual_tokens] and value[:num_actual_tokens] because
             # the reshape_and_cache_flash op uses the slot_mapping's shape
             # to determine the number of actual tokens.
-
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
-        if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(current_platform.fp8_dtype())
-            value_cache = value_cache.view(current_platform.fp8_dtype())
+            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
+                # Per-token quant scales may eventually be computed inside
+                # reshape_and_cache_shuffle_triton, which would differ from
+                # vLLM's per-layer scales, so the metadata-owned scales are used.
+                reshape_and_cache_shuffle_triton(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    attn_metadata.k_scale,
+                    attn_metadata.v_scale,
+                )
+            else:
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         # decode:extend:prefill
         query = query[:num_actual_tokens]
@@ -902,6 +1079,11 @@ class AiterFlashAttentionImpl(AttentionImpl):
             extend_keys = key[extend_tokens_slice]
             extend_values = value[extend_tokens_slice]
             extend_outputs = output[extend_tokens_slice]
+            k_scale = layer._k_scale
+            v_scale = layer._v_scale
+            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
+                k_scale = attn_metadata.k_scale
+                v_scale = attn_metadata.v_scale
             self.extend_forward(
                 attn_metadata=attn_metadata,
                 query=extend_querys,
@@ -920,14 +1102,17 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 slot_mapping=attn_metadata.slot_mapping[
                     num_decodes : num_decodes + num_extends
                 ],
-                k_scale=layer._k_scale,
-                v_scale=layer._v_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
 
         # calculate for decodes
         if num_decodes > 0:
             assert attn_metadata.decode_metadata is not None
             if self.sliding_window[0] != -1:
+                assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
+                    "Sliding window with shuffle layout is not supported yet."
+                )
                 from aiter.ops.triton.unified_attention import (
                     unified_attention,
                 )
@@ -957,41 +1142,71 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 )
                 return
             assert attn_metadata.decode_metadata is not None
-            _, num_heads, head_size = query.shape
-            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
-            num_seqs = attn_metadata.seq_lens.shape[0]
-            max_num_partitions = (
-                attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
-            ) // _PARTITION_SIZE_ROCM
-            workspace_buffer = torch.empty(
-                (num_seqs * num_heads * max_num_partitions * head_size)
-                * nbytes_per_qo_elem
-                + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
-                dtype=torch.uint8,
-                device=output.device,
-            )
+            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
+                num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
+                x = 16 // key_cache.element_size()
+                k_cache_template = torch.empty(
+                    [num_blocks, num_kv_heads, head_size // x, block_size, x],
+                    dtype=key_cache.dtype,
+                    device="meta",
+                )
+                v_cache_template = torch.empty(
+                    [num_blocks, num_kv_heads, block_size // x, head_size, x],
+                    dtype=value_cache.dtype,
+                    device="meta",
+                )
+                new_key_cache = key_cache.view_as(k_cache_template)
+                new_value_cache = value_cache.view_as(v_cache_template)
+                aiter.pa_fwd_asm(
+                    Q=query[:num_decode_tokens],
+                    K=new_key_cache,
+                    V=new_value_cache,
+                    block_tables=attn_metadata.block_table[:num_decodes],
+                    context_lens=attn_metadata.seq_lens[:num_decodes],
+                    block_tables_stride0=attn_metadata.block_table[
+                        :num_decodes
+                    ].stride(0),
+                    K_QScale=attn_metadata.k_scale,
+                    V_QScale=attn_metadata.v_scale,
+                    out_=output[:num_decode_tokens],
+                )
+            else:
+                _, num_heads, head_size = query.shape
+                nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
+                num_seqs = attn_metadata.seq_lens.shape[0]
+                max_num_partitions = (
+                    attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
+                ) // _PARTITION_SIZE_ROCM
 
-            torch.ops.aiter.paged_attention_v1(
-                output[:num_decode_tokens],
-                workspace_buffer,
-                query[:num_decode_tokens],
-                key_cache,
-                value_cache,
-                self.scale,
-                attn_metadata.block_table[:num_decodes],
-                attn_metadata.query_start_loc[:num_decodes],
-                attn_metadata.seq_lens[:num_decodes],
-                attn_metadata.max_seq_len,
-                self.alibi_slopes,
-                self.kv_cache_dtype,
-                "NHD",
-                self.logits_soft_cap,
-                layer._k_scale,
-                layer._v_scale,
-                None,
-                _PARTITION_SIZE_ROCM,
-            )
+                workspace_buffer = torch.empty(
+                    (num_seqs * num_heads * max_num_partitions * head_size)
+                    * nbytes_per_qo_elem
+                    + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
+                    dtype=torch.uint8,
+                    device=output.device,
+                )
+
+                torch.ops.aiter.paged_attention_v1(
+                    output[:num_decode_tokens],
+                    workspace_buffer,
+                    query[:num_decode_tokens],
+                    key_cache,
+                    value_cache,
+                    self.scale,
+                    attn_metadata.block_table[:num_decodes],
+                    attn_metadata.query_start_loc[:num_decodes],
+                    attn_metadata.seq_lens[:num_decodes],
+                    attn_metadata.max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                    "NHD",
+                    self.logits_soft_cap,
+                    layer._k_scale,
+                    layer._v_scale,
+                    None,
+                    _PARTITION_SIZE_ROCM,
+                )
         else:
             raise NotImplementedError(
                 "Cascade attention is not implemented for ROCM AITER"
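
[Review note] The shuffled layout is defined only implicitly by the kernels' pointer arithmetic, so here is a minimal pure-PyTorch sketch of the same transform, as a mental model and a possible reference for unit tests. The helper name `shuffle_kv_reference` is mine, not part of the patch; the shapes and `x = 16 // element_size` are taken directly from the code above.

```python
import torch


def shuffle_kv_reference(
    key_cache: torch.Tensor, value_cache: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference transform from the backend's flat layout
    [num_blocks, block_size, num_kv_heads, head_size] into the shuffled
    K: [num_blocks, num_kv_heads, head_size // x, block_size, x] and
    V: [num_blocks, num_kv_heads, block_size // x, head_size, x],
    where x = 16 // element_size (elements per 16-byte vector)."""
    num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
    x = 16 // key_cache.element_size()
    k = (
        key_cache.permute(0, 2, 3, 1)  # [blocks, heads, head_size, block_size]
        .reshape(num_blocks, num_kv_heads, head_size // x, x, block_size)
        .transpose(3, 4)  # move the x-sized group innermost
        .contiguous()
    )
    v = (
        value_cache.permute(0, 2, 1, 3)  # [blocks, heads, block_size, head_size]
        .reshape(num_blocks, num_kv_heads, block_size // x, x, head_size)
        .transpose(3, 4)
        .contiguous()
    )
    return k, v


# Spot-check that the reference agrees with the kernels' indexing:
# K[b, h, d // x, s, d % x] == key_cache[b, s, h, d] and
# V[b, h, s // x, d, s % x] == value_cache[b, s, h, d].
kc = torch.randn(2, 16, 4, 128, dtype=torch.bfloat16)
vc = torch.randn(2, 16, 4, 128, dtype=torch.bfloat16)
k, v = shuffle_kv_reference(kc, vc)
b, s, h, d = 1, 5, 3, 77
x = 16 // kc.element_size()  # 8 for bf16
assert k[b, h, d // x, s, d % x] == kc[b, s, h, d]
assert v[b, h, s // x, d, s % x] == vc[b, s, h, d]
```

Keeping `x` consecutive elements innermost packs each group into a single 16-byte chunk, which is presumably what lets the assembly paged-attention kernel issue full-width vector loads for every K/V fetch.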
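[Review note] The new path is off by default and double-gated: `is_shuffle_kv_cache_enabled()` requires both `_AITER_ENABLED` and the new flag. A sketch of how to opt in, assuming `VLLM_ROCM_USE_AITER` is the env var behind `_AITER_ENABLED` (that binding is outside this diff):

```python
import os

# Both variables must be set before vLLM reads its env config, since
# rocm_aiter_ops caches them at import/refresh time.
os.environ["VLLM_ROCM_USE_AITER"] = "1"  # assumed master AITER switch
os.environ["VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT"] = "1"  # new flag, default "False"

from vllm._aiter_ops import rocm_aiter_ops

# True only on a ROCm/AITER build with both flags on; on unsupported
# platforms the if_aiter_supported decorator presumably keeps this False.
print(rocm_aiter_ops.is_shuffle_kv_cache_enabled())
```

With the flag on, decode switches from `torch.ops.aiter.paged_attention_v1` to `aiter.pa_fwd_asm`, and cache writes go through `reshape_and_cache_shuffle_triton` instead of `torch.ops._C_cache_ops.reshape_and_cache_flash`.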