Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,18 +4,19 @@
|
||||
|
||||
Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.ptpc_fp8 import (
|
||||
PTPCFp8LinearMethod)
|
||||
from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
UNSUPPORTED_STR = (
|
||||
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
|
||||
"support output dtype of bfloat16. torch.float16 is specified.")
|
||||
"support output dtype of bfloat16. torch.float16 is specified."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@@ -24,18 +25,21 @@ def enable_pickle(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
|
||||
reason="PTPC FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="This test is for ROCm GPU.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("ptpc_fp8"),
|
||||
reason="PTPC FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="This test is for ROCm GPU.")
|
||||
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
|
||||
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
|
||||
try:
|
||||
llm = vllm_runner("facebook/opt-125m",
|
||||
dtype=dtype,
|
||||
quantization="ptpc_fp8",
|
||||
kv_cache_dtype=kv_cache_dtype)
|
||||
llm = vllm_runner(
|
||||
"facebook/opt-125m",
|
||||
dtype=dtype,
|
||||
quantization="ptpc_fp8",
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
except AssertionError as e:
|
||||
if str(e) == UNSUPPORTED_STR:
|
||||
# If the error message matches, the test passes
|
||||
|
||||
Reference in New Issue
Block a user