Enable RoPE+KV cache fusion for ROCm AITER FA (non-shuffle layout) (#35786)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: Attention,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
    """Report whether the fused RoPE + KV-cache-update path may be used.

    The fusion is only valid when AITER ops are enabled AND the shuffle
    KV-cache layout is not in use, because the shuffle layout takes a
    different cache-update path.
    """
    # Guard clause: without AITER ops there is nothing to fuse.
    if not rocm_aiter_ops.is_enabled():
        return False
    return not rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||
|
||||
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """Apply rotary embeddings to Q/K and write K/V into the cache in a
    single fused AITER Triton kernel.

    Uses the flash (non-shuffle) KV-cache layout; `kv_cache` is expected
    to stack key and value caches along dim 0.
    """
    k_cache, v_cache = kv_cache.unbind(0)

    # FP8 caches are stored as raw bytes; reinterpret them so the kernel
    # writes quantized values directly.
    uses_fp8 = self.kv_cache_dtype.startswith("fp8")
    if uses_fp8:
        fp8_dtype = current_platform.fp8_dtype()
        k_cache = k_cache.view(fp8_dtype)
        v_cache = v_cache.view(fp8_dtype)

    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_cache,
        v_cache,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        True,  # flash_layout: non-shuffle KV-cache layout
        uses_fp8,
    )
|
||||
|
||||
Reference in New Issue
Block a user