[Misc/Testing] Use torch.testing.assert_close (#7324)

This commit is contained in:
jon-chuang
2024-08-15 21:24:04 -07:00
committed by GitHub
parent e165528778
commit 50b8d08dbd
25 changed files with 197 additions and 188 deletions

View File

@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None:
# Reference dynamic quantizaton
y = quantize_ref(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
torch.testing.assert_close(ref_y,
per_tensor_dequantize(y, inv_scale, dtype))
# Static quantization
y, _ = ops.scaled_fp8_quant(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
torch.testing.assert_close(ref_y,
per_tensor_dequantize(y, inv_scale, dtype))
# Padding
y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
torch.testing.assert_close(
ref_y,
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
dtype))