[CI/Build][Kernel][BugFix][AMD] Fix per_token_group_quant_fp8 to use correct fp8 min/max values and update atol/rtol in test_quantfp8_group_functionality (#30292)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
@@ -762,9 +762,12 @@ def per_token_group_quant_fp8(
     )
 
     assert x.stride(-1) == 1, "`x` groups must be contiguous"
+    # Using the default value (240.0) from pytorch will cause accuracy
+    # issues on dynamic quantization models. Here we use 224.0 for fnuz on ROCm
+    # platforms that use the torch.float8_e4m3fnuz dtype.
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
-    fp8_max = finfo.max
+    fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
+    fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
 
     assert out_q is None or out_q.shape == x.shape
     x_q = out_q
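The effect of the clamp can be sketched without PyTorch. The helper names below (`quantize_group`, `group_quant_scale`) and the constants 240.0 / 224.0 are illustrative only: 240.0 is the reported max of the e4m3fnuz format, and 224.0 is the tighter limit the diff substitutes on fnuz platforms.

```python
# Minimal, framework-free sketch of per-group FP8 quantization with the
# fnuz clamp from the diff. Helper names are hypothetical, not vLLM APIs.

def group_quant_scale(group, fp8_max):
    """Scale so the group's largest magnitude maps to fp8_max."""
    amax = max(abs(v) for v in group)
    return amax / fp8_max if amax > 0 else 1.0

def quantize_group(group, use_fnuz_clamp):
    # Mirror of the diff's logic: on fnuz platforms clamp to +/-224.0
    # instead of the dtype's reported +/-240.0.
    fp8_max = 224.0 if use_fnuz_clamp else 240.0
    fp8_min = -fp8_max
    scale = group_quant_scale(group, fp8_max)
    # Scale, then clamp into the representable FP8 range.
    q = [min(max(v / scale, fp8_min), fp8_max) for v in group]
    return q, scale
```

With the clamp enabled, the largest magnitude in a group maps to 224.0 rather than 240.0, which is the behavior the commit's updated atol/rtol in `test_quantfp8_group_functionality` accounts for.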