[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(m: int,
|
||||
@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
|
||||
@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
|
||||
assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
|
||||
|
||||
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
|
||||
|
||||
@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=azp_bias[0, :])
|
||||
assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
|
||||
assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
|
||||
torch.testing.assert_close(a_dq,
|
||||
scale_a * aq_f32 - azp_a,
|
||||
rtol=1e-4,
|
||||
atol=1e-3)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
|
||||
@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
|
||||
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
|
||||
atol = 1e-3
|
||||
assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@@ -363,7 +366,7 @@ def test_cutlass_subset():
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
|
||||
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
Reference in New Issue
Block a user