[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)

This commit is contained in:
Michael Goin
2024-07-03 13:38:00 -04:00
committed by GitHub
parent 7cd2ebb025
commit 47f0954af0
11 changed files with 1585 additions and 42 deletions

View File

@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
pack_fp8_to_int32)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
@@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 89
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
def _create_scale_param(
self,
scale_name: str,
@@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
layer.process_after_load = True
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
@@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if self.use_marlin:
layer.marlin_state = GPTQMarlinState.REPACK
def prepare_layer_for_marlin(self, layer: Module) -> None:
    """Rewrite an FP8 layer in place into the form passed to the Marlin ops.

    Reads ``layer.weight``, ``layer.weight_scale``, ``layer.orig_dtype``,
    ``layer.input_size_per_partition`` and
    ``layer.output_size_per_partition``; replaces ``layer.weight`` and
    ``layer.weight_scale`` with repacked / permuted ``Parameter`` objects,
    attaches a zero-filled integer ``layer.workspace``, and flips
    ``layer.marlin_state`` from ``REPACK`` to ``READY``.

    Args:
        layer: Module whose FP8 weight and scale have already been loaded.

    Raises:
        AssertionError: if ``layer.marlin_state`` is not
            ``GPTQMarlinState.REPACK`` on entry.
    """
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition

    # This method must run exactly once per layer: REPACK -> READY.
    assert layer.marlin_state == GPTQMarlinState.REPACK
    layer.marlin_state = GPTQMarlinState.READY

    target_device = layer.weight.device

    # --- Weights -----------------------------------------------------
    # Step 1: pack the FP8 weight into int32 elements (GPTQ layout).
    fp8_as_int32 = pack_fp8_to_int32(layer.weight)

    # Step 2: convert the GPTQ layout into the Marlin layout. An empty
    # permutation tensor is supplied for ``perm``.
    empty_perm = torch.empty(0, dtype=torch.int, device=target_device)
    repacked_weight = ops.gptq_marlin_repack(
        b_q_weight=fp8_as_int32,
        perm=empty_perm,
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )
    layer.weight = Parameter(repacked_weight, requires_grad=False)

    # --- Weight scales -----------------------------------------------
    # Marlin has no per-tensor scale support at the moment, so the scale
    # is tiled ``size_n`` times to make it channelwise.
    # NOTE(review): assumes weight_scale holds one per-tensor value at
    # this point -- confirm against the caller.
    tiled_scale = layer.weight_scale.repeat(1, size_n)
    channelwise_scales = tiled_scale.to(layer.orig_dtype).to(target_device)

    # NOTE(review): group_size=-1 presumably selects channelwise (no
    # grouping along K) -- confirm against marlin_permute_scales.
    permuted_scales = marlin_permute_scales(
        s=channelwise_scales,
        size_k=size_k,
        size_n=size_n,
        group_size=-1,
        num_bits=8,
    )
    layer.weight_scale = Parameter(permuted_scales, requires_grad=False)

    # --- Workspace ---------------------------------------------------
    # Zero-initialised int buffer sized from the output partition width.
    workspace_len = (
        size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    layer.workspace = torch.zeros(workspace_len,
                                  dtype=torch.int,
                                  device=target_device,
                                  requires_grad=False)
def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
@@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.input_scale = None
if self.use_marlin:
self.prepare_layer_for_marlin(layer)
return
# If checkpoint is fp8, requantize the separately quantized logical
@@ -233,44 +309,72 @@ class Fp8LinearMethod(LinearMethodBase):
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
if self.use_marlin:
self.prepare_layer_for_marlin(layer)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
if self.use_marlin:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
output = ops.fp8_marlin_gemm(
a=reshaped_x,
b_q_weight=layer.weight,
b_scales=layer.weight_scale,
workspace=layer.workspace,
num_bits=8,
size_m=reshaped_x.shape[0],
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x
# If static, layer.input_scale is scalar and x_scale is input_scale
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
return torch.narrow(output, 0, 0, x.shape[0])