[Kernel] Enable fused_qknorm_rope_kernel supports partial rope (#30821)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-12-22 10:39:22 +08:00
committed by GitHub
parent 7e065eba59
commit 097978a15d
2 changed files with 64 additions and 52 deletions

View File

@@ -13,6 +13,7 @@ DTYPES = [torch.bfloat16, torch.float16]
IS_NEOX = [True, False]
EPS_VALUES = [1e-5, 1e-6]
SEEDS = [13]
PARTIAL_ROPE = [True, False]
CUDA_DEVICES = ["cuda:0"]
@@ -52,6 +53,7 @@ def _apply_qk_norm_rope(
@pytest.mark.parametrize("is_neox", IS_NEOX)
@pytest.mark.parametrize("eps", EPS_VALUES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
@torch.inference_mode()
def test_fused_qk_norm_rope_matches_reference(
device: str,
@@ -59,6 +61,7 @@ def test_fused_qk_norm_rope_matches_reference(
is_neox: bool,
eps: float,
seed: int,
rotary_ratio: float,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
@@ -76,10 +79,10 @@ def test_fused_qk_norm_rope_matches_reference(
k_norm.weight.data.normal_(mean=1.0, std=0.1)
q_weight = q_norm.weight.data
k_weight = k_norm.weight.data
rotary_dim = int(head_dim * rotary_ratio)
rope = RotaryEmbedding(
head_size=head_dim,
rotary_dim=head_dim,
rotary_dim=rotary_dim,
max_position_embeddings=4096,
base=10000.0,
is_neox_style=is_neox,