diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 566cb0239..7606c2a91 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -270,6 +270,9 @@ def test_rocm_wvsplitk_fp8_kernel( out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS) if xnorm: - assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8) + torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) + elif k >= 32 * 1024: + # wider pytorch threshold for large-K & no xnorm + torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2) else: - assert torch.allclose(out, ref_out, 0.01) + torch.testing.assert_close(out, ref_out, atol=0.01, rtol=0.01)