[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)
@@ -6,7 +6,7 @@ import pytest
 import torch

 from tests.quantization.utils import is_quant_method_supported
-from vllm._custom_ops import scaled_fp8_quant
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

 MODELS = [
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
     model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
     fc1 = model.model.decoder.layers[0].fc1
     assert isinstance(fc1.quant_method, Fp8LinearMethod)
-    assert fc1.weight.dtype == torch.float8_e4m3fn
+
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    if capability >= 89:
+        # For GPUs with hardware support, we keep weights in fp8
+        assert fc1.weight.dtype == torch.float8_e4m3fn
+    else:
+        # For GPUs without hardware support, we pack the fp8 weights
+        # for weight-only quantization using Marlin kernels
+        assert fc1.weight.dtype == torch.int32


 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
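The new assertions encode the dispatch this change introduces: GPUs with compute capability 8.9 or newer keep native float8_e4m3fn weights, while older GPUs such as Ampere have their fp8 weights packed into int32 tensors for the weight-only FP8 Marlin kernel. A minimal standalone sketch of that capability check, assuming only PyTorch; the function name fp8_weight_layout is illustrative and not part of the commit:

import torch

def fp8_weight_layout() -> str:
    # Mirror the test's check: fold the capability into a two-digit number.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability >= 89:
        # Hardware fp8 support (e.g. Ada, Hopper): weights stay float8_e4m3fn.
        return "native-fp8"
    # No hardware fp8 (e.g. Ampere): weights are repacked into int32 tensors
    # for the weight-only FP8 Marlin GEMM kernel.
    return "fp8-marlin"

if torch.cuda.is_available():
    print(fp8_weight_layout())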
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
     x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

     # Dynamic quantization
-    ref_y, inv_scale = scaled_fp8_quant(x, None)
+    ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
     ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)

     # Reference dynamic quantization
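In the dynamic path the op derives the per-tensor scale from the input, and the test dequantizes with that scale and compares against a plain PyTorch reference (the per_tensor_dequantize helper lives elsewhere in the test file and is not shown in this hunk). A rough sketch of what such a reference pair looks like, assuming a recent PyTorch build where float8_e4m3fn casts are available; ref_quantize and ref_dequantize are illustrative stand-ins, not the repository's helpers:

import torch

def ref_quantize(x: torch.Tensor):
    # Dynamic per-tensor quantization: derive the scale from the tensor itself.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().float() / finfo.max
    y = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return y, scale

def ref_dequantize(y: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
    # Undo the quantization: upcast and multiply back by the per-tensor scale.
    return (y.float() * scale).to(dtype)

x = torch.randn(11, 11, dtype=torch.float16) * 13
y, scale = ref_quantize(x)
x_hat = ref_dequantize(y, scale, torch.float16)
# The reconstruction differs from x only by fp8 rounding error.
print((x.float() - x_hat.float()).abs().max())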
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

     # Static quantization
-    y, _ = scaled_fp8_quant(x, inv_scale)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale)
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

     # Padding
-    y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
     assert y.shape[0] == 17
     assert torch.allclose(
         ref_y,
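Taken together, the updated test exercises three call patterns of the custom op: dynamic (scale computed by the op), static (caller-supplied scale), and padded (batch dimension grown to a requested size). A usage sketch, assuming a CUDA GPU and a vLLM build from this commit that exposes the op with the signature used above:

import torch
from vllm import _custom_ops as ops

x = (torch.randn(size=(11, 11), device="cuda") * 13).to(torch.float16)

# Dynamic: the op derives the per-tensor scale and returns it.
y_dyn, scale = ops.scaled_fp8_quant(x, None)

# Static: reuse a precomputed scale (here, the one from the dynamic pass).
y_static, _ = ops.scaled_fp8_quant(x, scale)

# Padding: grow the batch dimension of the quantized output to 17 rows.
y_padded, _ = ops.scaled_fp8_quant(x, scale, batch_dim_padding=17)
assert y_padded.shape[0] == 17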