[CI][AMD][Quantization][BugFix] Fix fp8 max in quant_utils.py and update test_fp8_quant.::test_static_fp8_quant_group_2d to use correct fp8 dtype and adjust atol/rtol (#32201)

Signed-off-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith
2026-01-14 23:04:34 -06:00
committed by GitHub
parent 773d7073ae
commit 3c2685645e
2 changed files with 4 additions and 3 deletions

View File

@@ -178,12 +178,12 @@ def test_static_fp8_quant_group_2d(
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = scaled_quantize(
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
x, group_shape, current_platform.fp8_dtype(), compute_dtype=torch.float32
)
ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)
torch.testing.assert_close(scale, ops_scale)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=1.2e-1, atol=1e-5)
opcheck_fp8_quant(ops_out, x, scale=scale)