Enable RoPE+KV cache fusion for ROCm AITER FA (non-shuffle layout) (#35786)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-03-13 02:33:22 -05:00
committed by GitHub
parent b373b5102a
commit a4ad9db541
2 changed files with 47 additions and 1 deletions

View File

@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.ROCM_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])

View File

@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
def do_kv_cache_update(
self,
layer: Attention,
layer: AttentionLayer,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
    """Return True when the fused RoPE + KV-cache-update path can be used.

    Fusion is only valid for the non-shuffle KV cache layout: the shuffle
    layout takes a different cache-update path that the fused kernel does
    not implement.
    """
    # Guard-clause form: AITER ops must be enabled at all...
    if not rocm_aiter_ops.is_enabled():
        return False
    # ...and the shuffle cache layout must not be active.
    return not rocm_aiter_ops.is_shuffle_kv_cache_enabled()
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """Apply RoPE to query/key and write key/value into the KV cache
    in a single fused Triton kernel.

    Args:
        layer: attention layer providing the per-layer ``_k_scale`` /
            ``_v_scale`` quantization scales.
        query, key, value: projected attention tensors; query and key are
            rotated in place by the kernel.
        positions: token position indices used to index ``cos_sin_cache``.
        cos_sin_cache: precomputed rotary cos/sin table.
        is_neox: whether to use the NeoX (interleaved) RoPE style.
        kv_cache: stacked (key, value) cache tensor; split via ``unbind(0)``.
        layer_slot_mapping: destination slot for each token in the cache.
    """
    k_cache, v_cache = kv_cache.unbind(0)
    uses_fp8 = self.kv_cache_dtype.startswith("fp8")
    if uses_fp8:
        # fp8 caches are stored as raw bytes; reinterpret them as the
        # platform's fp8 dtype before handing them to the kernel.
        fp8_dtype = current_platform.fp8_dtype()
        k_cache = k_cache.view(fp8_dtype)
        v_cache = v_cache.view(fp8_dtype)
    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_cache,
        v_cache,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        True,  # flash_layout: this backend uses the flash-attn cache layout
        uses_fp8,
    )