[ROCm] AITER fused RoPE+KVCache (#33443)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: charlifu <charlifu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
This commit is contained in:
Rohan Potdar
2026-02-23 21:06:00 -06:00
committed by GitHub
parent 95642441d0
commit 2ff4e51152
19 changed files with 1211 additions and 83 deletions

View File

@@ -11,7 +11,6 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention.attention import get_attention_context
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
@@ -1290,11 +1289,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
attn_metadata, _, _ = get_attention_context(layer.layer_name)
if attn_metadata is None:
# Profiling run.
return
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
@@ -1303,45 +1297,40 @@ class AiterFlashAttentionImpl(AttentionImpl):
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
# We may calculate per token quant scale in
# reshape_and_cache_shuffle_triton which might differ from
# vllm's style when shuffle layout is used.
k_scale = attn_metadata.k_scale
v_scale = attn_metadata.v_scale
assert k_scale is not None and v_scale is not None, (
"k_scale and v_scale are required for shuffled update"
)
reshape_and_cache_shuffle_triton(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
# We may calculate per token quant scale in
# reshape_and_cache_shuffle_triton which might differ from
# vllm's style when shuffle layout is used.
k_scale = layer._k_scale
v_scale = layer._v_scale
assert k_scale is not None and v_scale is not None, (
"k_scale and v_scale are required for shuffled update"
)
reshape_and_cache_shuffle_triton(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)