[ROCm]: Enable customop and rope+kvcache fusion for AITER RoPE (#35180)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -123,6 +123,8 @@ class PassConfig:
     """Enable async TP."""
     fuse_allreduce_rms: bool = Field(default=None)
     """Enable flashinfer allreduce fusion."""
+    enable_qk_norm_rope_fusion: bool = False
+    """Enable fused Q/K RMSNorm + RoPE pass."""

     # ROCm/AITER specific fusions
     fuse_act_padding: bool = Field(default=None)
@@ -153,8 +155,6 @@ class PassConfig:
                 8: 1, # 1MB
             },
         }, where key is the device capability"""
-    enable_qk_norm_rope_fusion: bool = False
-    """Enable fused Q/K RMSNorm + RoPE pass."""

     # TODO(luka) better pass enabling system.

@@ -834,23 +834,20 @@ class CompilationConfig:
                 func if isinstance(func, InductorPass) else CallableInductorPass(func)
             )

-        if self.pass_config.enable_qk_norm_rope_fusion:
+        if (
+            self.pass_config.enable_qk_norm_rope_fusion
+            and "+rotary_embedding" not in self.custom_ops
+        ):
             # TODO(zhuhaoran): support rope native forward match and remove this.
             # Linked issue: https://github.com/vllm-project/vllm/issues/28042
             self.custom_ops.append("+rotary_embedding")
-        if self.pass_config.fuse_rope_kvcache:
-            from vllm._aiter_ops import rocm_aiter_ops
-
-            if rocm_aiter_ops.is_triton_rotary_embed_enabled():
-                logger.warning(
-                    "Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
-                    "fuse_rope_kvcache. Disabling fuse_rope_kvcache."
-                )
-                self.pass_config.fuse_rope_kvcache = False
-            else:
-                # TODO(Rohan138): support rope native forward match and remove this.
-                # Linked issue: https://github.com/vllm-project/vllm/issues/28042
-                self.custom_ops.append("+rotary_embedding")
+        if (
+            self.pass_config.fuse_rope_kvcache
+            and "+rotary_embedding" not in self.custom_ops
+        ):
+            # TODO(Rohan138): support rope native forward match and remove this.
+            # Linked issue: https://github.com/vllm-project/vllm/issues/28042
+            self.custom_ops.append("+rotary_embedding")

         if (
             is_torch_equal_or_newer("2.9.0.dev")
Reference in New Issue
Block a user