[Kernel] Enable fused_qknorm_rope_kernel supports partial rope (#30821)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@ DTYPES = [torch.bfloat16, torch.float16]
|
||||
IS_NEOX = [True, False]
|
||||
EPS_VALUES = [1e-5, 1e-6]
|
||||
SEEDS = [13]
|
||||
PARTIAL_ROPE = [True, False]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
@@ -52,6 +53,7 @@ def _apply_qk_norm_rope(
|
||||
@pytest.mark.parametrize("is_neox", IS_NEOX)
|
||||
@pytest.mark.parametrize("eps", EPS_VALUES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
|
||||
@torch.inference_mode()
|
||||
def test_fused_qk_norm_rope_matches_reference(
|
||||
device: str,
|
||||
@@ -59,6 +61,7 @@ def test_fused_qk_norm_rope_matches_reference(
|
||||
is_neox: bool,
|
||||
eps: float,
|
||||
seed: int,
|
||||
rotary_ratio: float,
|
||||
):
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
@@ -76,10 +79,10 @@ def test_fused_qk_norm_rope_matches_reference(
|
||||
k_norm.weight.data.normal_(mean=1.0, std=0.1)
|
||||
q_weight = q_norm.weight.data
|
||||
k_weight = k_norm.weight.data
|
||||
|
||||
rotary_dim = int(head_dim * rotary_ratio)
|
||||
rope = RotaryEmbedding(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position_embeddings=4096,
|
||||
base=10000.0,
|
||||
is_neox_style=is_neox,
|
||||
|
||||
Reference in New Issue
Block a user