[Bugfix][Hardware][AMD] Consolidate FP8 min/max values helper function (#31106)

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: Kevin McKay <kevin@example.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Author: Kevin McKay
Date: 2026-01-07 00:55:03 -06:00 (committed by GitHub)
parent 482914849c
commit 4614c5a539
6 changed files with 102 additions and 39 deletions

@@ -19,6 +19,17 @@ FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
+
+def get_fp8_min_max() -> tuple[float, float]:
+    """Get the min and max values for FP8 quantization."""
+    # Using the default value (240.0) from PyTorch will cause accuracy
+    # issues on dynamic quantization models on ROCm. Here, use 224.0 for
+    # fnuz on ROCm platforms that use the torch.float8_e4m3fnuz dtype.
+    if current_platform.is_fp8_fnuz():
+        return -224.0, 224.0
+    finfo = torch.finfo(current_platform.fp8_dtype())
+    return finfo.min, finfo.max
+
 
 # Use proxy as NamedTuple direct subclasses cannot have static members
 class _GroupShape(NamedTuple):
     row: int
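The helper's branching can be sketched without torch or the vLLM platform layer. This is a minimal stand-in, not the vLLM API: the `is_fp8_fnuz` flag replaces `current_platform.is_fp8_fnuz()`, and the constants stand in for `torch.finfo` on the two FP8 dtypes (240.0 for `torch.float8_e4m3fnuz`, 448.0 for the OCP `torch.float8_e4m3fn`).

```python
# Hypothetical stand-ins for torch.finfo(dtype).max on the two FP8 dtypes.
FNUZ_MAX = 240.0  # torch.float8_e4m3fnuz (ROCm fnuz variant)
OCP_MAX = 448.0   # torch.float8_e4m3fn (OCP variant)


def get_fp8_min_max(is_fp8_fnuz: bool) -> tuple[float, float]:
    """Clamp range used for FP8 quantization."""
    if is_fp8_fnuz:
        # Per the commit message: 224.0, not the dtype's true 240.0 max,
        # avoids accuracy issues on dynamic-quantization models on ROCm.
        return -224.0, 224.0
    return -OCP_MAX, OCP_MAX


print(get_fp8_min_max(True))   # (-224.0, 224.0)
print(get_fp8_min_max(False))  # (-448.0, 448.0)
```

Consolidating this into one helper means every quantization path clamps to the same range, instead of each call site hard-coding its own fnuz special case.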