[ROCm] AITER fused RoPE+KVCache (#33443)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: charlifu <charlifu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Rohan Potdar authored on 2026-02-23 21:06:00 -06:00, committed by GitHub
parent 95642441d0
commit 2ff4e51152
19 changed files with 1211 additions and 83 deletions
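
For context: the two operations fused here are (1) applying rotary position embeddings (RoPE) to the query/key projections and (2) writing the new key/value vectors into the KV cache. The snippet below is only a simplified, self-contained reference of that unfused two-step pattern (flat non-paged cache, NeoX-style rotation, made-up shapes), not the vLLM/AITER implementation; the fused AITER kernel performs both steps in a single launch.

import torch

def apply_rope(q, k, cos, sin):
    # q, k: [num_tokens, num_heads, head_dim]; rotate channel pairs (NeoX style).
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

def write_kv_cache(k, v, key_cache, value_cache, slot_mapping):
    # key_cache/value_cache: [num_slots, num_heads, head_dim] (simplified layout).
    key_cache[slot_mapping] = k
    value_cache[slot_mapping] = v

num_tokens, num_heads, head_dim = 4, 2, 8
q, k, v = (torch.randn(num_tokens, num_heads, head_dim) for _ in range(3))
cos = torch.randn(num_tokens, 1, head_dim)
sin = torch.randn(num_tokens, 1, head_dim)
key_cache = torch.zeros(16, num_heads, head_dim)
value_cache = torch.zeros(16, num_heads, head_dim)
slot_mapping = torch.tensor([3, 7, 8, 12])

q, k = apply_rope(q, k, cos, sin)                           # launch 1: RoPE on Q/K
write_kv_cache(k, v, key_cache, value_cache, slot_mapping)  # launch 2: KV-cache write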


@@ -127,6 +127,13 @@ class PassConfig:
    # ROCm/AITER specific fusions
    fuse_act_padding: bool = Field(default=None)
    """Fuse the custom RMSNorm + padding ops."""
    fuse_rope_kvcache: bool = Field(default=None)
    """Fuse the QK rope + KV cache ops."""
    rope_kvcache_fusion_max_token_num: int = 256
    """Token-count threshold for the ROCm AITER RoPE+KVCache fusion, e.g. for
    small-batch decode. Larger token counts, e.g. during prefill, use the
    unfused kernels.
    """
    fi_allreduce_fusion_max_size_mb: float | None = None
    """The threshold of the communicated tensor sizes under which
@@ -198,6 +205,7 @@ class PassConfig:
        "fuse_gemm_comms",
        "fuse_allreduce_rms",
        "fuse_act_padding",
        "fuse_rope_kvcache",
        mode="wrap",
    )
    @classmethod
@@ -243,6 +251,12 @@ class PassConfig:
                "The fusion will be disabled."
            )
            self.fuse_act_padding = False
        if self.fuse_rope_kvcache and not current_platform.is_rocm():
            logger.warning_once(
                "KV cache fusion is currently only enabled on ROCm. "
                "The fusion will be disabled."
            )
            self.fuse_rope_kvcache = False


class DynamicShapesType(str, enum.Enum):
@@ -824,6 +838,19 @@ class CompilationConfig:
            # TODO(zhuhaoran): support rope native forward match and remove this.
            # Linked issue: https://github.com/vllm-project/vllm/issues/28042
            self.custom_ops.append("+rotary_embedding")
        if self.pass_config.fuse_rope_kvcache:
            from vllm._aiter_ops import rocm_aiter_ops

            if rocm_aiter_ops.is_triton_rotary_embed_enabled():
                logger.warning(
                    "Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
                    "fuse_rope_kvcache. Disabling fuse_rope_kvcache."
                )
                self.pass_config.fuse_rope_kvcache = False
            else:
                # TODO(Rohan138): support rope native forward match and remove this.
                # Linked issue: https://github.com/vllm-project/vllm/issues/28042
                self.custom_ops.append("+rotary_embedding")

        if (
            is_torch_equal_or_newer("2.9.0.dev")

@@ -1401,6 +1401,20 @@ class VllmConfig:
                    "allreduce-rms fusion will be enabled for all num_tokens."
                )

        if compilation_config.pass_config.fuse_rope_kvcache:
            max_token_num = (
                compilation_config.pass_config.rope_kvcache_fusion_max_token_num
            )
            if max_token_num is not None:
                if compile_range_end is not None and max_token_num < compile_range_end:
                    computed_compile_ranges_split_points.append(max_token_num)
                else:
                    logger.debug(
                        "Max num batched tokens is below the rope+kvcache fusion "
                        "threshold; rope+kvcache fusion will be enabled for "
                        "num_tokens <= %d.",
                        compile_range_end,
                    )

        if compilation_config.compile_ranges_split_points is not None:
            for x in compilation_config.compile_ranges_split_points:
                assert isinstance(x, int)
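
The effect of appending the threshold as a split point is that the fused RoPE+KVCache kernel only covers compile ranges at or below rope_kvcache_fusion_max_token_num. A worked example with an illustrative helper (not vLLM's actual range computation):

def compile_ranges(split_points: list[int], max_num_batched_tokens: int) -> list[tuple[int, int]]:
    # Turn split points into inclusive (start, end) num_tokens ranges.
    points = sorted({p for p in split_points if p < max_num_batched_tokens})
    ranges, start = [], 1
    for p in points:
        ranges.append((start, p))
        start = p + 1
    ranges.append((start, max_num_batched_tokens))
    return ranges

# rope_kvcache_fusion_max_token_num = 256, max_num_batched_tokens = 8192:
# the fused kernel applies only to the first range.
print(compile_ranges([256], 8192))  # [(1, 256), (257, 8192)]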