[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Author: zhrrr
Date: 2025-11-12 01:00:31 +08:00
Committed by: GitHub
parent a7ef3eb0cd
commit 68c09efc37
16 changed files with 1243 additions and 38 deletions

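For context, the pattern being fused: Qwen3-style attention applies a per-head RMSNorm to Q and K and then rotary embeddings as two separate ops, costing two kernel launches plus extra global-memory round-trips. A minimal PyTorch sketch of the unfused baseline (function name, argument layout, and shapes are illustrative, not the actual vLLM modules):

import torch
from torch import nn

def qk_norm_then_rope(q, k, positions, q_norm: nn.RMSNorm, k_norm: nn.RMSNorm,
                      rotary_emb, head_dim: int):
    # Unfused baseline: normalize every head over head_dim, then rotate.
    # The fused CUDA kernel in this commit performs both in a single launch.
    q = q_norm(q.view(-1, q.shape[-1] // head_dim, head_dim)).flatten(-2)
    k = k_norm(k.view(-1, k.shape[-1] // head_dim, head_dim)).flatten(-2)
    return rotary_emb(positions, q, k)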

@@ -98,6 +98,39 @@ class RotaryEmbedding(RotaryEmbeddingBase):
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
 
+    @staticmethod
+    def forward_static(
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor | None,
+        head_size: int,
+        rotary_dim: int,
+        cos_sin_cache: torch.Tensor,
+        is_neox_style: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """A PyTorch-native implementation of forward()."""
+        positions = positions.flatten()
+        num_tokens = positions.shape[0]
+        cos_sin = cos_sin_cache.index_select(0, positions)
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, head_size)
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+        # key may be None in some cases, e.g. cross-layer KV sharing
+        if key is not None:
+            key_shape = key.shape
+            key = key.view(num_tokens, -1, head_size)
+            key_rot = key[..., :rotary_dim]
+            key_pass = key[..., rotary_dim:]
+            key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
+            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
     def forward_native(
         self,
         positions: torch.Tensor,
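For reference, the rotation helper called above is not part of this hunk. A sketch consistent with vLLM's apply_rotary_emb_torch, where cos and sin each carry rotary_dim // 2 entries per position:

import torch

def apply_rotary_emb_torch(x, cos, sin, is_neox_style: bool):
    # x: [num_tokens, num_heads, rotary_dim]; cos/sin: [num_tokens, rotary_dim // 2]
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # NeoX style: rotate the first half of the dim against the second half.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style: rotate interleaved even/odd channel pairs.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    return torch.stack((o1, o2), dim=-1).flatten(-2)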
@@ -105,27 +138,15 @@ class RotaryEmbedding(RotaryEmbeddingBase):
         key: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """A PyTorch-native implementation of forward()."""
-        positions = positions.flatten()
-        num_tokens = positions.shape[0]
-        cos_sin = self.cos_sin_cache.index_select(0, positions)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style)
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-        # key may be None in some cases, e.g. cross-layer KV sharing
-        if key is not None:
-            key_shape = key.shape
-            key = key.view(num_tokens, -1, self.head_size)
-            key_rot = key[..., : self.rotary_dim]
-            key_pass = key[..., self.rotary_dim :]
-            key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style)
-            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-        return query, key
+        return self.forward_static(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.rotary_dim,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )
 
     def forward_cuda(
         self,
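Hoisting the native path into a @staticmethod removes its dependence on module state, so the fused kernel's fallback/reference path or a test can call it with plain tensors. A hypothetical standalone invocation, assuming the RotaryEmbedding class above is in scope; the cache construction follows the cat(cos, sin) layout that cos_sin.chunk(2, dim=-1) expects:

import torch

# Build the cos/sin cache: last dim = rotary_dim, cos half then sin half.
head_size = rotary_dim = 64
base, max_pos = 10000.0, 4096
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_pos, rotary_dim]

positions = torch.arange(4)
query = torch.randn(4, 8 * head_size)  # 8 query heads, flattened per token
key = torch.randn(4, 2 * head_size)    # 2 KV heads, flattened per token
q_out, k_out = RotaryEmbedding.forward_static(
    positions, query, key, head_size, rotary_dim, cos_sin_cache, is_neox_style=True
)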