[Misc/Testing] Use torch.testing.assert_close (#7324)

Author: jon-chuang
Date: 2024-08-15 21:24:04 -07:00
Committed by: GitHub
Parent: e165528778
Commit: 50b8d08dbd
25 changed files with 197 additions and 188 deletions
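
The change is mechanical throughout: every bare assert torch.allclose(...) in the kernel tests becomes a torch.testing.assert_close(...) call with the same tolerances. The payoff is in diagnostics: a failing bare assert reports nothing beyond the failed expression, while torch.testing.assert_close raises an AssertionError that includes the number of mismatched elements and the greatest absolute and relative differences, and it also verifies shape, dtype, and device by default. A minimal sketch of the difference, with illustrative values that are not part of this diff:

    import torch

    out = torch.tensor([1.0, 2.0, 3.5])
    baseline = torch.tensor([1.0, 2.0, 3.0])

    # Old pattern: on failure, only the failed expression is shown.
    # assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)

    # New pattern: on failure, the raised AssertionError reports the
    # mismatched-element count and the greatest absolute/relative
    # difference; shape, dtype, and device are checked by default.
    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)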


@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
+    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
 def cutlass_int8_gemm_helper(m: int,
@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
 @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
     azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding
     a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
-    assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
+    torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
     baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                 scale_b,
                                 out_dtype=out_dtype,
                                 bias=azp_bias[0, :])
-    assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
-    assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
+    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
+    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
 @pytest.mark.parametrize("m", [32, 64, 128])
@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
     azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding
     a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
-    assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
+    torch.testing.assert_close(a_dq,
+                               scale_a * aq_f32 - azp_a,
+                               rtol=1e-4,
+                               atol=1e-3)
     if use_bias:
         bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
     # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
     rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
     atol = 1e-3
-    assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
-    assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
+    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
+    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
 # Test working with a subset of A and B
@@ -363,7 +366,7 @@ def test_cutlass_subset():
                                 scale_b,
                                 out_dtype=torch.bfloat16)
-    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
 # Test to make sure cuda graphs work
@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
-    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
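
One behavioral note on the hunk at -252: the call there passes no tolerances. torch.allclose always defaults to rtol=1e-5 and atol=1e-8, whereas torch.testing.assert_close derives its defaults from the input dtype (for float32, rtol=1.3e-6 and atol=1e-5), so dropping explicit tolerances slightly changes what the test accepts. assert_close is also stricter in that it requires equal shapes rather than broadcasting, and its extra checks can be relaxed per call. A hypothetical relaxation, not part of this commit:

    # Compare across dtypes: with check_dtype=False, the inputs are
    # promoted to a common dtype before the value comparison.
    torch.testing.assert_close(out, baseline, check_dtype=False,
                               rtol=1e-1, atol=1e0)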