[CI][AMD][Quantization][BugFix] Fix fp8 max in quant_utils.py and update test_fp8_quant.py::test_static_fp8_quant_group_2d to use correct fp8 dtype and adjust atol/rtol (#32201)
Signed-off-by: Randall Smith <ransmith@amd.com>
@@ -178,12 +178,12 @@ def test_static_fp8_quant_group_2d(
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
     ref_out, scale = scaled_quantize(
-        x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
+        x, group_shape, current_platform.fp8_dtype(), compute_dtype=torch.float32
     )
     ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)

     torch.testing.assert_close(scale, ops_scale)
-    torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
+    torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=1.2e-1, atol=1e-5)

     opcheck_fp8_quant(ops_out, x, scale=scale)
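The dtype change is the substance of the fix: PyTorch ships two fp8 "e4m3" flavors with different dynamic ranges, and hard-coding FP8_DTYPE gives the reference quantizer the wrong fp8 max on AMD GPUs that use the fnuz variant. A minimal sketch of the distinction; the fp8_dtype helper below is a simplified, hypothetical stand-in for vLLM's current_platform.fp8_dtype(), which dispatches on the actual GPU architecture:

    import torch

    # The two fp8 e4m3 flavors have different maxima, so a scale computed
    # as amax / fp8_max with the wrong dtype is off by roughly 1.87x.
    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (OCP variant, NVIDIA et al.)
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (ROCm MI300-class GPUs)

    def fp8_dtype() -> torch.dtype:
        """Simplified platform dispatch, for illustration only."""
        return torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn

The tolerance change pairs with this: with atol=0.0, torch.testing.assert_close permits zero error on any element whose reference value is exactly 0.0 (the relative bound rtol * |expected| collapses to zero there), so the small absolute floor of 1e-5 presumably keeps near-zero elements from failing spuriously.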