[Bugfix][Hardware][AMD] Consolidate FP8 min/max values helper function (#31106)
Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: Kevin McKay <kevin@example.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
@@ -4,13 +4,13 @@

 import torch

-from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    get_fp8_min_max,
+    group_broadcast,
+)
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import round_up

-# Using the default value (240.0) from pytorch will cause accuracy
-# issue on dynamic quantization models. Here use 224.0 for rocm.
-ROCM_FP8FNUZ_MAX = 224.0
 FP8_DTYPE = current_platform.fp8_dtype()


@@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant(
     if scale_ub is not None:
         assert quant_dtype == FP8_DTYPE

-    qtype_traits = (
-        torch.iinfo(quant_dtype)
-        if quant_dtype == torch.int8
-        else torch.finfo(quant_dtype)
-    )
-    use_fp8fnuz = (
-        current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
-    )
-    qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
-    qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
+    if quant_dtype == torch.int8:
+        qtype_traits = torch.iinfo(quant_dtype)
+        qtype_traits_min = qtype_traits.min
+        qtype_traits_max = qtype_traits.max
+    else:
+        qtype_traits_min, qtype_traits_max = get_fp8_min_max()
     qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)
@@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant(
 def ref_dynamic_per_tensor_fp8_quant(
     x: torch.Tensor,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    fp8_traits = torch.finfo(FP8_DTYPE)
-    fp8_traits_max = (
-        ROCM_FP8FNUZ_MAX
-        if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
-        else fp8_traits.max
-    )
-    fp8_traits_min = (
-        -ROCM_FP8FNUZ_MAX
-        if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
-        else fp8_traits.min
-    )
+    fp8_traits_min, fp8_traits_max = get_fp8_min_max()
     fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)

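For reference, the consolidated helper presumably centralizes the logic that the hunks above delete from both call sites: clamp to ±224.0 on FP8-FNUZ hardware (the removed ROCM_FP8FNUZ_MAX) and fall back to torch.finfo limits otherwise. A minimal sketch of such a helper, assuming it lives in quant_utils and takes no arguments as in the new call sites; this is an illustration, not the exact vLLM implementation:

    import torch

    from vllm.platforms import current_platform

    # Assumed constant, carried over from the deleted ROCM_FP8FNUZ_MAX: torch's
    # default fnuz finfo max (240.0) causes accuracy issues on dynamic
    # quantization models, so ROCm clamps to 224.0.
    _FP8_FNUZ_MAX = 224.0


    def get_fp8_min_max() -> tuple[float, float]:
        # Return the (min, max) representable values for the platform's FP8 dtype.
        if current_platform.is_fp8_fnuz():
            return -_FP8_FNUZ_MAX, _FP8_FNUZ_MAX
        traits = torch.finfo(current_platform.fp8_dtype())
        return float(traits.min), float(traits.max)

Both reference quantization paths above then consume the returned pair directly, keeping the 224.0 FNUZ clamp in one place instead of duplicating it per caller.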