[ROCm] [aiter] Split KV cache update for AiterFlashAttention (#33681)
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
@@ -11,6 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention.attention import get_attention_context
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
@@ -687,6 +688,8 @@ class AiterFlashAttentionBackend(AttentionBackend):
    def get_supported_head_sizes(cls) -> list[int]:
        return [64, 128, 256]

    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"
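The new forward_includes_kv_cache_update flag (default False here) tells callers that this backend no longer writes key/value into the paged KV cache inside forward(). A minimal sketch of what that enables, assuming a hypothetical runner-side helper; run_attention, backend, and impl are illustrative names and simplified signatures, not vLLM's actual runner code:

def run_attention(backend, impl, layer, query, key, value, kv_cache, slot_mapping, attn_metadata):
    # When the backend opts out of in-forward cache updates, the caller
    # issues the paged KV cache write as a separate, explicit step.
    if not backend.forward_includes_kv_cache_update:
        impl.do_kv_cache_update(layer, key, value, kv_cache, slot_mapping)
    # forward() then only needs to read from the already-updated cache.
    return impl.forward(layer, query, key, value, kv_cache, attn_metadata)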
@@ -982,49 +985,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
        # performance to make sure it does not introduce any overhead.
        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(0)
        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping
            # is not padded. However, we don't need to do
            # key[:num_actual_tokens] and value[:num_actual_tokens] because
            # the reshape_and_cache_flash op uses the slot_mapping's shape
            # to determine the number of actual tokens.
            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
                # We may calculate a per-token quant scale in
                # reshape_and_cache_shuffle_triton, which might differ from
                # vLLM's style when the shuffle layout is used.
                reshape_and_cache_shuffle_triton(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    attn_metadata.k_scale,
                    attn_metadata.v_scale,
                )
            else:
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        # decode:extend:prefill
        query = query[:num_actual_tokens]
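As a side note on the NOTE(woosuk) comment above: key and value may carry trailing padding while slot_mapping does not, and the cache write is bounded by slot_mapping's length. A tiny, self-contained sketch of that behaviour, with plain PyTorch indexing standing in for the real reshape_and_cache op and made-up shapes:

import torch

num_padded_tokens, num_heads, head_size, num_slots = 8, 2, 4, 16
key = torch.randn(num_padded_tokens, num_heads, head_size)    # padded
value = torch.randn(num_padded_tokens, num_heads, head_size)  # padded
slot_mapping = torch.tensor([5, 9, 12])                       # only real tokens

key_cache = torch.zeros(num_slots, num_heads, head_size)
value_cache = torch.zeros(num_slots, num_heads, head_size)

# Only the first len(slot_mapping) tokens are written; the trailing padding
# in key/value is never touched, so no explicit slicing is needed upstream.
key_cache[slot_mapping] = key[: slot_mapping.numel()]
value_cache[slot_mapping] = value[: slot_mapping.numel()]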
@@ -1215,3 +1179,67 @@ class AiterFlashAttentionImpl(AttentionImpl):
        )

        return output

    def do_kv_cache_update(
        self,
        layer: Attention,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ):
        attn_metadata, _, _ = get_attention_context(layer.layer_name)
        if attn_metadata is None:
            # Profiling run.
            return

        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping
            # is not padded. However, we don't need to do
            # key[:num_actual_tokens] and value[:num_actual_tokens] because
            # the reshape_and_cache_flash op uses the slot_mapping's shape
            # to determine the number of actual tokens.
            if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
                # We may calculate a per-token quant scale in
                # reshape_and_cache_shuffle_triton, which might differ from
                # vLLM's style when the shuffle layout is used.
                k_scale = attn_metadata.k_scale
                v_scale = attn_metadata.v_scale
                assert k_scale is not None and v_scale is not None, (
                    "k_scale and v_scale are required for shuffled update"
                )
                reshape_and_cache_shuffle_triton(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    slot_mapping,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            else:
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
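A rough usage sketch for the new do_kv_cache_update method, under the assumption (not part of this diff) that a caller holds parallel per-layer containers of attention implementations, Attention layers, KV caches, and key/value tensors, and wants to issue every cache write before launching the attention forward passes:

def update_all_kv_caches(impls, layers, kv_caches, keys, values, slot_mapping):
    # impls/layers/kv_caches/keys/values are hypothetical parallel per-layer
    # lists; the only call taken from this diff is do_kv_cache_update itself.
    for impl, layer, kv_cache, key, value in zip(impls, layers, kv_caches, keys, values):
        impl.do_kv_cache_update(layer, key, value, kv_cache, slot_mapping)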