[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
|
||||
# Reference dynamic quantizaton
|
||||
y = quantize_ref(x, inv_scale)
|
||||
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||
torch.testing.assert_close(ref_y,
|
||||
per_tensor_dequantize(y, inv_scale, dtype))
|
||||
|
||||
# Static quantization
|
||||
y, _ = ops.scaled_fp8_quant(x, inv_scale)
|
||||
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||
torch.testing.assert_close(ref_y,
|
||||
per_tensor_dequantize(y, inv_scale, dtype))
|
||||
|
||||
# Padding
|
||||
y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
|
||||
assert y.shape[0] == 17
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
ref_y,
|
||||
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
|
||||
dtype))
|
||||
|
||||
Reference in New Issue
Block a user