diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py
index 09679fb41..d9554f6fb 100644
--- a/tests/compile/passes/test_rope_kvcache_fusion.py
+++ b/tests/compile/passes/test_rope_kvcache_fusion.py
@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
         AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
         AttentionBackendEnum.TRITON_ATTN,
         AttentionBackendEnum.ROCM_ATTN,
+        AttentionBackendEnum.ROCM_AITER_FA,
     ],
 )
 @pytest.mark.parametrize("enable_rope_custom_op", [True])  # [True, False])
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index b1adaa724..e756766f4 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,
@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
 
     def do_kv_cache_update(
         self,
-        layer: Attention,
+        layer: AttentionLayer,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl):
             layer._k_scale,
             layer._v_scale,
         )
+
+    def fused_rope_kvcache_supported(self):
+        # Only support fusion when shuffle KV cache layout is not used;
+        # shuffle layout uses a different cache update path.
+        return (
+            rocm_aiter_ops.is_enabled()
+            and not rocm_aiter_ops.is_shuffle_kv_cache_enabled()
+        )
+
+    def do_rope_and_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        positions: torch.Tensor,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+        kv_cache: torch.Tensor,
+        layer_slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = kv_cache.unbind(0)
+        flash_layout = True
+
+        is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
+        if is_fp8_kv_cache:
+            key_cache = key_cache.view(current_platform.fp8_dtype())
+            value_cache = value_cache.view(current_platform.fp8_dtype())
+
+        rocm_aiter_ops.triton_rope_and_cache(
+            query,
+            key,
+            value,
+            positions,
+            cos_sin_cache,
+            is_neox,
+            key_cache,
+            value_cache,
+            layer_slot_mapping,
+            layer._k_scale,
+            layer._v_scale,
+            flash_layout,
+            is_fp8_kv_cache,
+        )