[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="Only test on cuda platform",
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_qk_norm_rope_fusion(
|
||||
eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype
|
||||
|
||||
Reference in New Issue
Block a user