[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)

This commit is contained in:
Michael Goin
2024-07-03 13:38:00 -04:00
committed by GitHub
parent 7cd2ebb025
commit 47f0954af0
11 changed files with 1585 additions and 42 deletions

View File

@@ -6,7 +6,7 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm._custom_ops import scaled_fp8_quant
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
MODELS = [
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
# Dynamic quantization
ref_y, inv_scale = scaled_fp8_quant(x, None)
ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
# Reference dynamic quantization
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Static quantization
y, _ = scaled_fp8_quant(x, inv_scale)
y, _ = ops.scaled_fp8_quant(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Padding
y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
ref_y,