[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhrrr
2025-11-12 01:00:31 +08:00
committed by GitHub
parent a7ef3eb0cd
commit 68c09efc37
16 changed files with 1243 additions and 38 deletions

View File

@@ -129,6 +129,8 @@ class PassConfig:
8: 1, # 1MB
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@@ -182,6 +184,12 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False
@config
@@ -640,6 +648,11 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
if self.pass_config.enable_qk_norm_rope_fusion:
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")
and "combo_kernels" not in self.inductor_compile_config