[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
141
tests/kernels/core/test_fused_qk_norm_rope.py
Normal file
141
tests/kernels/core/test_fused_qk_norm_rope.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
IS_NEOX = [True, False]
|
||||
EPS_VALUES = [1e-5, 1e-6]
|
||||
SEEDS = [13]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
def _apply_qk_norm_rope(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_norm: RMSNorm,
|
||||
k_norm: RMSNorm,
|
||||
rope: RotaryEmbedding,
|
||||
num_heads_q: int,
|
||||
num_heads_kv: int,
|
||||
head_dim: int,
|
||||
) -> torch.Tensor:
|
||||
q_size = num_heads_q * head_dim
|
||||
kv_size = num_heads_kv * head_dim
|
||||
|
||||
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
|
||||
|
||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
|
||||
q_by_head = q_norm.forward_native(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
|
||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
|
||||
k_by_head = k_norm.forward_native(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
|
||||
q, k = rope.forward_native(positions, q, k)
|
||||
return torch.cat([q, k, v], dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="fused_qk_norm_rope custom op requires cuda platform",
|
||||
)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("is_neox", IS_NEOX)
|
||||
@pytest.mark.parametrize("eps", EPS_VALUES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_fused_qk_norm_rope_matches_reference(
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
is_neox: bool,
|
||||
eps: float,
|
||||
seed: int,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
num_heads, num_kv_heads, head_dim = 16, 4, 128
|
||||
num_tokens = 4
|
||||
|
||||
total_dim = (num_heads + 2 * num_kv_heads) * head_dim
|
||||
qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device)
|
||||
qkv_fused = qkv_base.clone()
|
||||
positions = torch.arange(num_tokens, dtype=torch.long, device=device)
|
||||
|
||||
q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
|
||||
k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype)
|
||||
q_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
k_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
q_weight = q_norm.weight.data
|
||||
k_weight = k_norm.weight.data
|
||||
|
||||
rope = RotaryEmbedding(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000.0,
|
||||
is_neox_style=is_neox,
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
ref_result = _apply_qk_norm_rope(
|
||||
qkv=qkv_base,
|
||||
positions=positions,
|
||||
q_norm=q_norm,
|
||||
k_norm=k_norm,
|
||||
rope=rope,
|
||||
num_heads_q=num_heads,
|
||||
num_heads_kv=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.fused_qk_norm_rope,
|
||||
(
|
||||
qkv_fused.clone(),
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
eps,
|
||||
q_weight,
|
||||
k_weight,
|
||||
rope.cos_sin_cache,
|
||||
is_neox,
|
||||
positions.view(-1),
|
||||
),
|
||||
)
|
||||
|
||||
torch.ops._C.fused_qk_norm_rope(
|
||||
qkv_fused,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
eps,
|
||||
q_weight,
|
||||
k_weight,
|
||||
rope.cos_sin_cache,
|
||||
is_neox,
|
||||
positions.view(-1),
|
||||
)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(
|
||||
qkv_fused,
|
||||
ref_result,
|
||||
atol=ATOL,
|
||||
rtol=RTOL,
|
||||
)
|
||||
Reference in New Issue
Block a user