From 64f570ab56cab7e8977c611b78f9a44a9a9f033c Mon Sep 17 00:00:00 2001
From: kliuae <17350011+kliuae@users.noreply.github.com>
Date: Thu, 12 Feb 2026 00:26:44 +0800
Subject: [PATCH] [ROCm] [aiter] Split KV cache update for AiterFlashAttention (#33681)

Signed-off-by: kliuae
---
 vllm/v1/attention/backends/rocm_aiter_fa.py | 108 ++++++++++++--------
 1 file changed, 68 insertions(+), 40 deletions(-)

diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 28b5a7f41..4be650f93 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -11,6 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.attention.attention import get_attention_context
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import get_cu_count
@@ -687,6 +688,8 @@ class AiterFlashAttentionBackend(AttentionBackend):
     def get_supported_head_sizes(cls) -> list[int]:
         return [64, 128, 256]
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN"
@@ -982,49 +985,10 @@
         # performance to make sure it does not introduce any overhead.
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
+
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(current_platform.fp8_dtype())
             value_cache = value_cache.view(current_platform.fp8_dtype())
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping
-            # is not padded. However, we don't need to do
-            # key[:num_actual_tokens] and value[:num_actual_tokens] because
-            # the reshape_and_cache_flash op uses the slot_mapping's shape
-            # to determine the number of actual tokens.
-            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
-                # We may calculate per token quant scale in
-                # reshape_and_cache_shuffle_triton which might differ from
-                # vllm's style when shuffle layout is used.
-                reshape_and_cache_shuffle_triton(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    attn_metadata.k_scale,
-                    attn_metadata.v_scale,
-                )
-            else:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
 
         # decode:extend:prefill
         query = query[:num_actual_tokens]
@@ -1215,3 +1179,67 @@
             )
 
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: Attention,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        attn_metadata, _, _ = get_attention_context(layer.layer_name)
+        if attn_metadata is None:
+            # Profiling run.
+            return
+
+        key_cache, value_cache = kv_cache.unbind(0)
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if self.kv_cache_dtype.startswith("fp8"):
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping
+            # is not padded. However, we don't need to do
+            # key[:num_actual_tokens] and value[:num_actual_tokens] because
+            # the reshape_and_cache_flash op uses the slot_mapping's shape
+            # to determine the number of actual tokens.
+            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
+                # We may calculate per token quant scale in
+                # reshape_and_cache_shuffle_triton which might differ from
+                # vllm's style when shuffle layout is used.
+                k_scale = attn_metadata.k_scale
+                v_scale = attn_metadata.v_scale
+                assert k_scale is not None and v_scale is not None, (
+                    "k_scale and v_scale are required for shuffled update"
+                )
+                reshape_and_cache_shuffle_triton(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+            else:
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
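
Note for reviewers: the sketch below is illustrative only and is not part of this patch or of vllm's code. It assumes a hypothetical SimpleBackend/run_layer pair to show the calling convention implied by forward_includes_kv_cache_update = False: the runner issues the KV cache write through do_kv_cache_update() as a separate step, and forward() only reads from the cache. The real call site lives in vllm's model runner and is not shown here.

# Hypothetical, simplified driver; names SimpleBackend and run_layer are
# invented for this sketch and do not exist in vllm.
import torch


class SimpleBackend:
    # Mirrors the flag added by this patch: when False, the caller is
    # responsible for invoking the KV cache update before forward().
    forward_includes_kv_cache_update: bool = False

    def do_kv_cache_update(self, key, value, kv_cache, slot_mapping):
        # Scatter the new K/V rows into the cache at the given slots.
        # unbind(0) returns views, so writes land in kv_cache's storage.
        key_cache, value_cache = kv_cache.unbind(0)
        key_cache[slot_mapping] = key
        value_cache[slot_mapping] = value

    def forward(self, query, kv_cache):
        # Attention would read from kv_cache here; omitted in this sketch.
        return query


def run_layer(backend, query, key, value, kv_cache, slot_mapping):
    # Splitting the update out lets the cache write be scheduled
    # independently of the attention kernel.
    if not backend.forward_includes_kv_cache_update:
        backend.do_kv_cache_update(key, value, kv_cache, slot_mapping)
    return backend.forward(query, kv_cache)


if __name__ == "__main__":
    num_slots, num_tokens, head_dim = 16, 4, 8
    kv_cache = torch.zeros(2, num_slots, head_dim)
    q = torch.randn(num_tokens, head_dim)
    k = torch.randn(num_tokens, head_dim)
    v = torch.randn(num_tokens, head_dim)
    slots = torch.tensor([3, 7, 8, 12])
    out = run_layer(SimpleBackend(), q, k, v, kv_cache, slots)
    assert torch.allclose(kv_cache[0, slots], k)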