[Bugfix][Hardware][AMD] Consolidate FP8 min/max values helper function (#31106)
Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: Kevin McKay <kevin@example.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,17 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def get_fp8_min_max() -> tuple[float, float]:
    """Return the (min, max) representable values for FP8 quantization.

    On non-fnuz platforms this is simply ``torch.finfo`` of the platform's
    FP8 dtype. On ROCm platforms using ``torch.float8_e4m3fnuz``, the
    PyTorch default (240.0) causes accuracy issues for dynamically
    quantized models, so the range is clamped to +/-224.0 instead.
    """
    if not current_platform.is_fp8_fnuz():
        info = torch.finfo(current_platform.fp8_dtype())
        return info.min, info.max
    # fnuz on ROCm: use 224.0 rather than the 240.0 default from PyTorch.
    return -224.0, 224.0
|
||||
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
|
||||
Reference in New Issue
Block a user