[CI][AMD][Quantization][BugFix] Fix fp8 max in quant_utils.py and update test_fp8_quant.::test_static_fp8_quant_group_2d to use correct fp8 dtype and adjust atol/rtol (#32201)

Signed-off-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith
2026-01-14 23:04:34 -06:00
committed by GitHub
parent 773d7073ae
commit 3c2685645e
2 changed files with 4 additions and 3 deletions

View File

@@ -221,7 +221,8 @@ def scaled_quantize(
# Compute scales
min_val, max_val = x_blkd_permd.aminmax(dim=-1)
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
_, fp8_max = get_fp8_min_max()
scale = fp8_max / amax
# Apply scale and convert form:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)