[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-02-24 22:36:40 -06:00
committed by GitHub
parent ec1d30c0f6
commit f38f8c9742
9 changed files with 139 additions and 67 deletions

View File

@@ -123,6 +123,8 @@ class PassConfig:
"""Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None)
@@ -153,8 +155,6 @@ class PassConfig:
8: 1, # 1MB
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@@ -834,23 +834,20 @@ class CompilationConfig:
func if isinstance(func, InductorPass) else CallableInductorPass(func)
)
if self.pass_config.enable_qk_norm_rope_fusion:
if (
self.pass_config.enable_qk_norm_rope_fusion
and "+rotary_embedding" not in self.custom_ops
):
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if self.pass_config.fuse_rope_kvcache:
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
logger.warning(
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
)
self.pass_config.fuse_rope_kvcache = False
else:
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
self.pass_config.fuse_rope_kvcache
and "+rotary_embedding" not in self.custom_ops
):
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")