[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2025-11-12 05:01:14 -08:00
committed by GitHub
parent c5f10cc139
commit edb59a9470
6 changed files with 37 additions and 38 deletions

View File

@@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Only test on cuda platform",
not current_platform.is_cuda_alike(),
reason="Only test on cuda and rocm platform",
)
def test_qk_norm_rope_fusion(
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype