[ROCm] [aiter] Split KV cache update for AiterFlashAttention (#33681)

Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
This commit is contained in:
kliuae
2026-02-12 00:26:44 +08:00
committed by GitHub
parent fd618871b4
commit 64f570ab56

View File

@@ -11,6 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention.attention import get_attention_context
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
@@ -687,6 +688,8 @@ class AiterFlashAttentionBackend(AttentionBackend):
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head dimensions this backend supports."""
    supported = (64, 128, 256)
    return list(supported)
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@@ -982,49 +985,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
# We may calculate per token quant scale in
# reshape_and_cache_shuffle_triton which might differ from
# vllm's style when shuffle layout is used.
reshape_and_cache_shuffle_triton(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
attn_metadata.k_scale,
attn_metadata.v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# decode:extend:prefill
query = query[:num_actual_tokens]
@@ -1215,3 +1179,67 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
return output
def do_kv_cache_update(
    self,
    layer: Attention,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store the given key/value tensors into the paged KV cache.

    Performs the cache write that was previously part of forward(),
    so the update can be invoked separately from attention itself.
    Does nothing during profiling runs or when this layer shares its
    KV cache with an earlier attention layer.
    """
    attn_metadata, _, _ = get_attention_context(layer.layer_name)
    if attn_metadata is None:
        # Profiling run — no metadata, nothing to write.
        return

    key_cache, value_cache = kv_cache.unbind(0)
    if self.kv_cache_dtype.startswith("fp8"):
        fp8_dtype = current_platform.fp8_dtype()
        key_cache = key_cache.view(fp8_dtype)
        value_cache = value_cache.view(fp8_dtype)

    # key/value may be None for cross attention: they are computed once
    # from the encoder output and then served from the KV cache. Also
    # skip the write when sharing KV cache with an earlier layer.
    skip_write = (
        self.kv_sharing_target_layer_name is not None
        or key is None
        or value is None
    )
    if skip_write:
        return

    # NOTE(woosuk): key and value are padded while slot_mapping is not;
    # no explicit truncation is needed because the cache ops derive the
    # actual token count from slot_mapping's shape.
    if not rocm_aiter_ops.is_shuffle_kv_cache_enabled():
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )
        return

    # Shuffled layout: reshape_and_cache_shuffle_triton may compute
    # per-token quant scales that differ from vLLM's convention, so the
    # metadata-provided scales are used instead of the layer's.
    k_scale = attn_metadata.k_scale
    v_scale = attn_metadata.v_scale
    assert k_scale is not None and v_scale is not None, (
        "k_scale and v_scale are required for shuffled update"
    )
    reshape_and_cache_shuffle_triton(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        self.kv_cache_dtype,
        k_scale,
        v_scale,
    )