diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 7606c2a91..2564f1829 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -155,9 +155,9 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) if xnorm: - assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) else: - assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-2) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @@ -177,7 +177,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): ref_out = torch.matmul(A, B.t()) out = ops.LLMM1(B, A, rows_per_block) - assert torch.allclose(out, ref_out, rtol=0.01) + torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @@ -194,7 +194,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): ref_out = torch.nn.functional.linear(A, B) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) - assert torch.allclose(out, ref_out, rtol=0.01) + torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @@ -213,7 +213,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): ref_out = torch.nn.functional.linear(A, B, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) - assert torch.allclose(out, ref_out, rtol=0.01) + torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @@ -232,7 +232,7 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): ref_out = torch.nn.functional.linear(A, B, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) - assert torch.allclose(out, ref_out, rtol=0.01) + torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) @pytest.mark.parametrize("xnorm", [False, True]) @@ -275,4 +275,4 @@ def test_rocm_wvsplitk_fp8_kernel( # wider pytrch thresh for large-K & no xnorm torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2) else: - torch.testing.assert_close(out, ref_out, atol=0.01, rtol=0.01) + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)