[Bugfix][Hardware][AMD] Consolidate FP8 min/max values helper function (#31106)
Signed-off-by: c0de128 <kevin.mckay@outlook.com> Signed-off-by: Kevin McKay <kevin@example.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
65
tests/kernels/quantization/test_fp8_min_max_helper.py
Normal file
65
tests/kernels/quantization/test_fp8_min_max_helper.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for the get_fp8_min_max() helper function.
|
||||
|
||||
These tests verify the FP8 min/max value logic for both standard
|
||||
and fnuz (ROCm MI300) dtype handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_fp8_min_max,
|
||||
)
|
||||
|
||||
|
||||
class TestGetFp8MinMax:
|
||||
"""Test cases for get_fp8_min_max() function."""
|
||||
|
||||
@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
|
||||
def test_standard_fp8_platform(self, mock_platform):
|
||||
"""Test that standard FP8 platform uses PyTorch's finfo values."""
|
||||
mock_platform.is_fp8_fnuz.return_value = False
|
||||
mock_platform.fp8_dtype.return_value = torch.float8_e4m3fn
|
||||
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
|
||||
# Standard FP8 max is 448.0 for e4m3fn
|
||||
assert fp8_max == finfo.max, f"Expected finfo.max={finfo.max}, got {fp8_max}"
|
||||
assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}"
|
||||
|
||||
@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
|
||||
def test_fnuz_platform_returns_224(self, mock_platform):
|
||||
"""Test that fnuz platform returns 224.0."""
|
||||
mock_platform.is_fp8_fnuz.return_value = True
|
||||
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
|
||||
# fnuz on ROCm MI300 should return 224.0, not 240.0
|
||||
assert fp8_max == 224.0, f"Expected 224.0 for fnuz platform, got {fp8_max}"
|
||||
assert fp8_min == -224.0, f"Expected -224.0 for fnuz platform, got {fp8_min}"
|
||||
|
||||
@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
|
||||
def test_non_fnuz_platform_uses_finfo(self, mock_platform):
|
||||
"""Test that non-fnuz platform uses finfo values."""
|
||||
mock_platform.is_fp8_fnuz.return_value = False
|
||||
mock_platform.fp8_dtype.return_value = torch.float8_e4m3fn
|
||||
|
||||
fp8_min, fp8_max = get_fp8_min_max()
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
|
||||
assert fp8_max == finfo.max, (
|
||||
f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}"
|
||||
)
|
||||
assert fp8_min == finfo.min, (
|
||||
f"Non-fnuz platform should use finfo.min={finfo.min}, got {fp8_min}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user