[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
@@ -129,6 +129,8 @@ class PassConfig:
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}, where key is the device capability"""
|
||||
enable_qk_norm_rope_fusion: bool = False
|
||||
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
@@ -182,6 +184,12 @@ class PassConfig:
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"Allreduce + rms norm + quant (fp8) fusion might not work"
|
||||
)
|
||||
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
|
||||
logger.warning_once(
|
||||
"QK Norm + RoPE fusion enabled but the current platform is not "
|
||||
"CUDA. The fusion will be disabled."
|
||||
)
|
||||
self.enable_qk_norm_rope_fusion = False
|
||||
|
||||
|
||||
@config
|
||||
@@ -640,6 +648,11 @@ class CompilationConfig:
|
||||
if isinstance(self.pass_config, dict):
|
||||
self.pass_config = PassConfig(**self.pass_config)
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
# TODO(zhuhaoran): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
|
||||
if (
|
||||
is_torch_equal_or_newer("2.9.0.dev")
|
||||
and "combo_kernels" not in self.inductor_compile_config
|
||||
|
||||
Reference in New Issue
Block a user