[Bugfix][Hardware][AMD] Consolidate FP8 min/max values helper function (#31106)

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: Kevin McKay <kevin@example.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Kevin McKay
2026-01-07 00:55:03 -06:00
committed by GitHub
parent 482914849c
commit 4614c5a539
6 changed files with 102 additions and 39 deletions

View File

@@ -4,13 +4,13 @@
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
group_broadcast,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
@@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant(
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
qtype_traits = (
torch.iinfo(quant_dtype)
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
use_fp8fnuz = (
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
)
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
if quant_dtype == torch.int8:
qtype_traits = torch.iinfo(quant_dtype)
qtype_traits_min = qtype_traits.min
qtype_traits_max = qtype_traits.max
else:
qtype_traits_min, qtype_traits_max = get_fp8_min_max()
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant(
def ref_dynamic_per_tensor_fp8_quant(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.max
)
fp8_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.min
)
fp8_traits_min, fp8_traits_max = get_fp8_min_max()
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)