[Kernel][Perf] Fuse QK Norm and RoPE into one CUDA kernel for Qwen models (#27165)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhrrr
2025-11-12 01:00:31 +08:00
committed by GitHub
parent a7ef3eb0cd
commit 68c09efc37
16 changed files with 1243 additions and 38 deletions

View File

@@ -329,6 +329,7 @@ def rms_norm(
out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
# If removed, also remove the contiguous call in MatcherRMSNorm
input_contiguous = input.contiguous()
torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
@@ -339,6 +340,34 @@ def fused_add_rms_norm(
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def fused_qk_norm_rope(
qkv: torch.Tensor,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_dim: int,
eps: float,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
position_ids: torch.Tensor,
) -> None:
torch.ops._C.fused_qk_norm_rope(
qkv,
num_heads_q,
num_heads_k,
num_heads_v,
head_dim,
eps,
q_weight,
k_weight,
cos_sin_cache,
is_neox,
position_ids,
)
def apply_repetition_penalties_torch(
logits: torch.Tensor,
prompt_mask: torch.Tensor,

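For reference, a minimal usage sketch of the new fused_qk_norm_rope wrapper above (assuming this hunk lives in vllm/_custom_ops.py and the CUDA kernel added by this PR is compiled in; the head counts, eps, weights, and cache below are illustrative values, not taken from any real model config):

import torch
from vllm import _custom_ops as ops

# Qwen-style GQA shapes: 32 query heads, 8 KV heads, head_dim 128 (illustrative).
num_heads_q, num_heads_k, num_heads_v, head_dim = 32, 8, 8, 128
num_tokens = 16
qkv = torch.randn(
    num_tokens,
    (num_heads_q + num_heads_k + num_heads_v) * head_dim,
    dtype=torch.bfloat16,
    device="cuda",
)
q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
cos_sin_cache = torch.randn(4096, head_dim, dtype=torch.bfloat16, device="cuda")
position_ids = torch.arange(num_tokens, device="cuda")

# Mutates qkv in place: per-head RMSNorm on the Q/K slices, then RoPE on Q/K.
ops.fused_qk_norm_rope(
    qkv,
    num_heads_q,
    num_heads_k,
    num_heads_v,
    head_dim,
    1e-6,
    q_weight,
    k_weight,
    cos_sin_cache,
    True,  # is_neox
    position_ids,
)

The V slice of qkv is left untouched; callers split qkv after the call, exactly as the fusion pass replacement later in this commit does.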
View File

@@ -132,6 +132,23 @@ class FixFunctionalizationPass(VllmInductorPass):
"input_global_scale",
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
mutated_args = {1: "qkv"}
args = (
"qkv",
"num_heads_q",
"num_heads_k",
"num_heads_v",
"head_dim",
"eps",
"q_weight",
"k_weight",
"cos_sin_cache",
"is_neox",
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
else:
continue # skip the count

View File

@@ -44,6 +44,10 @@ def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
def empty_i64(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

View File

@@ -18,10 +18,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
@@ -58,6 +61,9 @@ class MatcherCustomOp(ABC):
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty_int64(self, *args, **kws):
return torch.empty(*args, dtype=torch.int64, device=self.device, **kws)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
@@ -66,6 +72,77 @@ class MatcherCustomOp(ABC):
raise NotImplementedError
class MatcherRotaryEmbedding(MatcherCustomOp):
def __init__(
self,
is_neox: bool,
head_size: int,
num_heads: int,
num_kv_heads: int,
use_flashinfer: bool = False,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RotaryEmbedding.enabled()
super().__init__(enabled)
self.is_neox = is_neox
self.head_size = head_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.q_size = self.num_heads * self.head_size
self.kv_size = self.num_kv_heads * self.head_size
self.rotary_dim = head_size
if use_flashinfer:
self.rotary_op = FLASHINFER_ROTARY_OP
else:
self.rotary_op = ROTARY_OP
def inputs(self) -> list[torch.Tensor]:
positions = self.empty_int64(5)
query = self.empty(5, self.q_size)
key = self.empty(5, self.kv_size)
cos_sin_cache = self.empty(4096, self.rotary_dim)
return [positions, query, key, cos_sin_cache]
def forward_custom(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
result = auto_functionalized(
self.rotary_op,
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
)
query_out = result[1]
key_out = result[2] if len(result) > 2 else None
return query_out, key_out
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return RotaryEmbedding.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
cos_sin_cache,
self.is_neox,
)
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
@@ -85,10 +162,12 @@ class MatcherRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
# TODO: support non-contiguous input for RMSNorm and remove this
input_contiguous = input.contiguous()
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input,
input=input_contiguous,
weight=weight,
epsilon=self.epsilon,
)

View File

@@ -17,6 +17,7 @@ if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
@@ -109,6 +110,9 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_qk_norm_rope_fusion:
self.passes += [QKNormRoPEFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)

View File

@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from .fusion import empty_bf16, empty_fp32, empty_i64
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
class QkNormRopePattern:
"""
Match the unfused sequence in attention blocks and replace with the fused op.
Unfused (conceptually):
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
qh = reshape(q, [-1, num_heads, head_dim])
kh = reshape(k, [-1, num_kv_heads, head_dim])
qn = rms_norm(qh, q_weight, eps)
kn = rms_norm(kh, k_weight, eps)
qf = reshape(qn, [-1, num_heads * head_dim])
kf = reshape(kn, [-1, num_kv_heads * head_dim])
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
return qf, kf, v
Fused replacement:
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
eps, q_weight, k_weight, cos_sin_cache, is_neox,
positions.view(-1))
return split(qkv, [qsz, kvsz, kvsz], -1)
"""
def __init__(
self,
head_dim: int,
num_heads: int,
num_kv_heads: int,
eps: float,
is_neox: bool,
rope_flashinfer: bool = False,
) -> None:
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=is_neox,
head_size=self.head_dim,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
use_flashinfer=self.rope_flashinfer,
)
def get_inputs(self):
# Sample inputs to help pattern tracing
T = 5
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
positions = empty_i64(T)
q_weight = empty_bf16(1, self.head_dim)
k_weight = empty_bf16(1, self.head_dim)
if self.rope_flashinfer:
cos_sin_cache = empty_fp32(4096, self.head_dim)
else:
cos_sin_cache = empty_bf16(4096, self.head_dim)
return [
qkv,
positions,
q_weight,
k_weight,
cos_sin_cache,
]
@staticmethod
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Q path: view -> RMS -> view back to q.shape
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k
q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
return q_rope, k_rope, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
):
# Run fused qk_norm_rope op
result = auto_functionalized(
FUSED_QK_ROPE_OP,
qkv=qkv,
num_heads_q=self.num_heads,
num_heads_k=self.num_kv_heads,
num_heads_v=self.num_kv_heads,
head_dim=self.head_dim,
eps=self.eps,
q_weight=q_weight,
k_weight=k_weight,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
position_ids=positions.view(-1),
)
result_qkv = result[1]
# Split back to q,k,v and return
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# NOTE: use fx_view_to_reshape to normalize view/reshape ops, which simplifies
# the pattern and increases matching opportunities
pm.register_replacement(
pattern,
replacement,
self.get_inputs(),
QkNormRopePattern.wrap_trace_fn(
pm.fwd_only,
QkNormRopePattern.fx_view_to_reshape,
),
pm_pass,
)
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="qk_norm_rope_fusion_pass"
)
dtype = config.model_config.dtype
if dtype not in (torch.bfloat16, torch.float16):
logger.warning_once(
"QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
)
return
# use one attention layer to get metadata (such as head_dim) for QkNormRopePattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
config, Attention
)
if len(attn_layers) == 0:
logger.warning_once(
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
)
return
layer = next(iter(attn_layers.values()))
for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
for rope_flashinfer in [False, True]:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
rope_flashinfer=rope_flashinfer,
).register(self.patterns)
else:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
def uuid(self):
return VllmInductorPass.hash_source(self, QkNormRopePattern)

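The class docstring above describes the unfused subgraph this pass matches. A plain-PyTorch sketch of that reference path, useful for eyeballing what gets replaced (the inline rms_norm helper is a stand-in for vLLM's RMSNorm semantics, and forward_static is the helper added to RotaryEmbedding later in this commit):

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

def unfused_qk_norm_rope_reference(
    qkv, positions, q_weight, k_weight, cos_sin_cache,
    num_heads, num_kv_heads, head_dim, eps, is_neox,
):
    """Unfused split -> per-head RMSNorm -> RoPE path that the pass replaces."""
    q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    def rms_norm(x, weight):
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (normed * weight.float()).to(x.dtype)

    # Per-head RMSNorm on Q and K; V passes through untouched.
    q_normed = rms_norm(q.reshape(-1, num_heads, head_dim), q_weight).reshape(q.shape)
    k_normed = rms_norm(k.reshape(-1, num_kv_heads, head_dim), k_weight).reshape(k.shape)

    # RoPE on the flattened Q/K.
    q_rot, k_rot = RotaryEmbedding.forward_static(
        positions, q_normed, k_normed, head_dim, head_dim, cos_sin_cache, is_neox
    )
    return q_rot, k_rot, v

The fused replacement produces the same three tensors by mutating qkv in place with fused_qk_norm_rope and splitting it afterwards, avoiding the intermediate reshapes and the separate norm/RoPE kernel launches.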
View File

@@ -129,6 +129,8 @@ class PassConfig:
8: 1, # 1MB
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Whether to enable the fused Q/K RMSNorm + RoPE pass."""
# TODO(luka) better pass enabling system.
@@ -182,6 +184,12 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda():
logger.warning_once(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled."
)
self.enable_qk_norm_rope_fusion = False
@config
@@ -640,6 +648,11 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
if self.pass_config.enable_qk_norm_rope_fusion:
# TODO(zhuhaoran): support matching the native RoPE forward and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")
and "combo_kernels" not in self.inductor_compile_config

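A hedged sketch of enabling the new pass from offline-inference code (the dict-valued compilation_config shown here is how recent vLLM versions accept it, but the exact plumbing may differ by release; the model name is only an example of a model with QK norm):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",
    compilation_config={
        "pass_config": {"enable_qk_norm_rope_fusion": True},
    },
)
outputs = llm.generate("Hello, my name is")

Note that, per the hunk above, enabling the pass also appends "+rotary_embedding" to custom_ops, because only the custom RoPE op is matched until the native forward pattern is supported.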
View File

@@ -98,6 +98,39 @@ class RotaryEmbedding(RotaryEmbeddingBase):
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
@staticmethod
def forward_static(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
rotary_dim: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""A PyTorch-native implementation of forward()."""
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_native(
self,
positions: torch.Tensor,
@@ -105,27 +138,15 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""A PyTorch-native implementation of forward()."""
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
return self.forward_static(
positions,
query,
key,
self.head_size,
self.rotary_dim,
self.cos_sin_cache,
self.is_neox_style,
)
def forward_cuda(
self,
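Finally, a small self-contained sketch of calling the new static helper directly. The [cos | sin] cache layout below mirrors what RotaryEmbeddingBase precomputes (an assumption, since the base class is not shown in this diff), and the sizes are arbitrary:

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

head_size, rotary_dim, max_pos, base = 64, 64, 4096, 10000.0

# Build a [max_pos, rotary_dim] cache laid out as [cos | sin] on the last dim.
inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
)
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

positions = torch.randint(0, max_pos, (7,))
query = torch.randn(7, 8 * head_size)  # 8 query heads, flattened
key = torch.randn(7, 2 * head_size)    # 2 KV heads, flattened

q_out, k_out = RotaryEmbedding.forward_static(
    positions, query, key, head_size, rotary_dim, cos_sin_cache, is_neox_style=True
)
assert q_out.shape == query.shape and k_out.shape == key.shape

Because the helper is a @staticmethod with no reference to self, both the instance forward_native above and the MatcherRotaryEmbedding used by the fusion pass can share it.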