[ROCm] AITER fused RoPE+KVCache (#33443)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: charlifu <charlifu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Rohan Potdar committed 2026-02-23 21:06:00 -06:00 (committed by GitHub)
parent 95642441d0
commit 2ff4e51152
19 changed files with 1211 additions and 83 deletions
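
In brief: the diff below adds a capability probe (fused_rope_kvcache_supported) and a fused-path method (do_rope_and_kv_cache_update) to the Triton attention backend, so RoPE and the KV-cache write can be dispatched to a single AITER Triton kernel when AITER is enabled on ROCm. A hypothetical caller-side sketch of how such a gate is typically consumed; apply_rope and write_kv_cache are illustrative placeholders, not names from this commit:

    # Hypothetical dispatch sketch (not code from this commit).
    if impl.fused_rope_kvcache_supported():
        # One fused kernel: rotate q/k and scatter k/v into the paged cache.
        impl.do_rope_and_kv_cache_update(
            layer, query, key, value, positions, cos_sin_cache,
            is_neox, kv_cache, slot_mapping,
        )
    else:
        # Unfused fallback: two separate passes over the same tensors.
        apply_rope(query, key, positions, cos_sin_cache, is_neox)
        write_kv_cache(layer, key, value, kv_cache, slot_mapping)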


@@ -7,6 +7,7 @@ from typing import ClassVar

 import torch

+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
@@ -596,3 +597,42 @@ class TritonAttentionImpl(AttentionImpl):
             layer._k_scale,
             layer._v_scale,
         )
+
+    def fused_rope_kvcache_supported(self):
+        return rocm_aiter_ops.is_enabled()
+
+    def do_rope_and_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        positions: torch.Tensor,
+        cos_sin_cache: torch.Tensor,
+        is_neox: bool,
+        kv_cache: torch.Tensor,
+        layer_slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = kv_cache.unbind(1)
+        flash_layout = True
+        is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
+        if is_fp8_kv_cache:
+            key_cache = key_cache.view(self.fp8_dtype)
+            value_cache = value_cache.view(self.fp8_dtype)
+        rocm_aiter_ops.triton_rope_and_cache(
+            query,
+            key,
+            value,
+            positions,
+            cos_sin_cache,
+            is_neox,
+            key_cache,
+            value_cache,
+            layer_slot_mapping,
+            layer._k_scale,
+            layer._v_scale,
+            flash_layout,
+            is_fp8_kv_cache,
+        )
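
For readers tracing the numerics: the AITER kernel fuses the rotary position embedding with the flash-layout cache scatter. Below is a minimal pure-PyTorch sketch of the NeoX-style semantics, assuming a flash-layout cache of shape [num_blocks, block_size, num_kv_heads, head_size] and a cos_sin_cache that stores the cos half followed by the sin half per position; fp8 quantization and the GPT-J interleaved rotation are omitted, and all names here are illustrative rather than taken from AITER:

    import torch

    def rope_and_cache_reference(
        query: torch.Tensor,          # [num_tokens, num_heads, head_size]
        key: torch.Tensor,            # [num_tokens, num_kv_heads, head_size]
        value: torch.Tensor,          # [num_tokens, num_kv_heads, head_size]
        positions: torch.Tensor,      # [num_tokens], int64 position ids
        cos_sin_cache: torch.Tensor,  # [max_position, rot_dim]; cos half then sin half
        key_cache: torch.Tensor,      # [num_blocks, block_size, num_kv_heads, head_size]
        value_cache: torch.Tensor,    # same flash layout as key_cache
        slot_mapping: torch.Tensor,   # [num_tokens]; flat slot = block * block_size + offset
    ):
        rot_dim = cos_sin_cache.shape[-1]
        # Look up cos/sin for each token's position; each is [num_tokens, rot_dim // 2].
        cos, sin = cos_sin_cache.index_select(0, positions).chunk(2, dim=-1)
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads

        def rotate_neox(x: torch.Tensor) -> torch.Tensor:
            # NeoX style: split the rotary dims into two halves and rotate as pairs.
            x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
            x1, x2 = x_rot.chunk(2, dim=-1)
            return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin, x_pass), dim=-1)

        query = rotate_neox(query)
        key = rotate_neox(key)

        # Scatter the rotated keys and untouched values into the paged caches.
        block_size = key_cache.shape[1]
        block_idx = slot_mapping // block_size
        block_off = slot_mapping % block_size
        key_cache[block_idx, block_off] = key
        value_cache[block_idx, block_off] = value
        return query, key

Both steps are memory-bound passes over the same q/k/v tensors, so fusing them avoids a round trip through HBM between the rotation and the cache write, which is presumably the motivation for routing them through one Triton kernel here.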