[CI][AMD][Quantization][BugFix] Fix fp8 max in quant_utils.py and update test_fp8_quant.py::test_static_fp8_quant_group_2d to use correct fp8 dtype and adjust atol/rtol (#32201)
Signed-off-by: Randall Smith <ransmith@amd.com>
@@ -221,7 +221,8 @@ def scaled_quantize(
     # Compute scales
     min_val, max_val = x_blkd_permd.aminmax(dim=-1)
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
+    _, fp8_max = get_fp8_min_max()
+    scale = fp8_max / amax
 
     # Apply scale and convert from:
     # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
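
The companion test change follows the same logic: test_static_fp8_quant_group_2d must quantize against the platform's actual fp8 dtype, and its tolerances must reflect e4m3's three mantissa bits. A hypothetical round-trip check in that spirit follows; the helper name, default dtype, and atol/rtol values are illustrative, not the test's.

import torch

def fp8_roundtrip_close(x: torch.Tensor,
                        fp8_dtype: torch.dtype = torch.float8_e4m3fnuz) -> None:
    # Static per-tensor quantize -> dequantize round trip.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = fp8_max / x.abs().amax().clamp(min=1e-12)
    q = (x * scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    deq = q.to(x.dtype) / scale
    # e4m3 keeps only 3 mantissa bits, so fp16-grade tolerances fail;
    # these atol/rtol values are stand-ins for whatever the test uses.
    torch.testing.assert_close(deq, x, atol=0.2, rtol=0.2)

fp8_roundtrip_close(torch.randn(128, 128))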