[Bugfix][Hardware][AMD] Consolidate FP8 min/max values helper function (#31106)

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: Kevin McKay <kevin@example.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Kevin McKay committed 2026-01-07 00:55:03 -06:00 (committed via GitHub)
parent 482914849c
commit 4614c5a539
6 changed files with 102 additions and 39 deletions

View File

@@ -7,15 +7,14 @@ import torch.nn.functional as F
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    get_fp8_min_max,
+)
 from vllm.platforms import current_platform
-# Using the default value (240.0) from pytorch will cause accuracy
-# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
 _FP8_DTYPE = current_platform.fp8_dtype()
-_FP8_FINFO = torch.finfo(_FP8_DTYPE)
-_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
-_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
+_FP8_MIN, _FP8_MAX = get_fp8_min_max()
 _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
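
For context (not part of the diff): the fnuz cap exists because torch.float8_e4m3fnuz has a smaller representable range than torch.float8_e4m3fn, and its native maximum of 240.0 was observed to hurt accuracy on dynamically quantized models on ROCm. A minimal sketch, assuming a PyTorch build that exposes both FP8 dtypes:

import torch

# Native ranges of the two FP8 e4m3 variants; the helper added in this
# commit returns +/-224.0 for the fnuz variant instead of its native 240.0.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0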

View File

@@ -15,7 +15,10 @@ from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    get_fp8_min_max,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
 )
@@ -748,12 +751,7 @@ def per_token_group_quant_fp8(
     )
     assert x.stride(-1) == 1, "`x` groups must be contiguous"
-    # Using the default value (240.0) from pytorch will cause accuracy
-    # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
-    # platforms that use the torch.float8_e4m3fnuz dtype.
-    finfo = torch.finfo(dtype)
-    fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
-    fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
+    fp8_min, fp8_max = get_fp8_min_max()
     assert out_q is None or out_q.shape == x.shape
     x_q = out_q
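
For orientation, a simplified eager-mode sketch of the kind of per-token-group clamping that fp8_min/fp8_max bound inside per_token_group_quant_fp8; the function below is illustrative only (its name and the eps default are assumptions), not the fused kernel vLLM actually dispatches to:

import torch

def _quantize_group_sketch(
    x: torch.Tensor,          # one contiguous group of a single token
    fp8_min: float,
    fp8_max: float,
    dtype: torch.dtype,       # e.g. current_platform.fp8_dtype()
    eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Scale so the group's absolute maximum maps to fp8_max, then clamp
    # into [fp8_min, fp8_max] before casting down to the FP8 dtype.
    amax = x.abs().max().clamp(min=eps).to(torch.float32)
    scale = amax / fp8_max
    x_q = (x.to(torch.float32) / scale).clamp(fp8_min, fp8_max).to(dtype)
    return x_q, scale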

View File

@@ -19,6 +19,17 @@ FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
 
+def get_fp8_min_max() -> tuple[float, float]:
+    """Get the min and max values for FP8 quantization."""
+    # Using the default value (240.0) from pytorch will cause accuracy
+    # issue on dynamic quantization models on ROCm. Here, use 224.0 for fnuz
+    # on ROCm platforms that use the torch.float8_e4m3fnuz dtype.
+    if current_platform.is_fp8_fnuz():
+        return -224.0, 224.0
+    finfo = torch.finfo(current_platform.fp8_dtype())
+    return finfo.min, finfo.max
+
+
 # Use proxy as NamedTuple direct subclasses cannot have static members
 class _GroupShape(NamedTuple):
     row: int
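
An illustrative call site for the consolidated helper, mirroring how the two files above use it; the tensor below is made up:

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max
from vllm.platforms import current_platform

fp8_min, fp8_max = get_fp8_min_max()
x = torch.randn(8, dtype=torch.float32) * 500.0
# Clamp into the platform's safe FP8 range before casting down to FP8.
x_q = x.clamp(fp8_min, fp8_max).to(current_platform.fp8_dtype())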