[ Misc ] Support Fp8 via llm-compressor (#6110)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Author: Robert Shaw
Date: 2024-07-07 16:42:11 -04:00
Committed by: GitHub
parent 333306a252
commit abfe705a02
17 changed files with 603 additions and 372 deletions


@@ -1,9 +1,11 @@
"""This file is used for /tests and /benchmarks"""
import random
from typing import Optional
import numpy
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.format_24 import (
    mask_creator, sparse_semi_structured_from_dense_cutlass)
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
@@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_pack_factor, quantize_weights, sort_weights)
from vllm.platforms import current_platform
from vllm.utils import print_warning_once

MARLIN_TILE = 16
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]

def is_marlin_supported():
@@ -22,7 +32,92 @@ def is_marlin_supported():
return capability[0] >= 8
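
For context, is_marlin_supported gates on the CUDA compute capability; a minimal sketch of the check's intent, assuming current_platform.get_device_capability() returns a (major, minor) tuple:

# Minimal sketch: Marlin kernels require SM 8.0+ (Ampere or newer).
capability = current_platform.get_device_capability()  # e.g. (8, 0) on A100
marlin_ok = capability[0] >= 8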
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):

def apply_fp8_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
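
For illustration, a hedged usage sketch of apply_fp8_marlin_linear. The wrapper function and its name are hypothetical; it assumes `layer` has already been processed by prepare_fp8_layer_for_marlin below, so weight, weight_scale, and workspace are in Marlin format:

def fp8_marlin_forward_sketch(layer: torch.nn.Module,
                              x: torch.Tensor) -> torch.Tensor:
    """Hypothetical forward pass through a Marlin-prepared FP8 layer."""
    return apply_fp8_marlin_linear(
        input=x,                          # shape (..., size_k)
        weight=layer.weight,              # Marlin-repacked FP8 weights
        weight_scale=layer.weight_scale,  # channelwise scales
        workspace=layer.workspace,
        size_n=layer.output_size_per_partition,
        size_k=layer.input_size_per_partition,
        bias=getattr(layer, "bias", None),
    )                                     # returns shape (..., size_n)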

def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    device = layer.weight.device

    # WEIGHTS
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = layer.weight_scale.repeat(1, part_size_n).to(
        layer.orig_dtype).to(device)

    # Permute scales
    num_bits = 8
    marlin_scales = marlin_permute_scales(
        s=scales,
        size_k=part_size_k,
        size_n=part_size_n,
        group_size=-1,
        scale_perm=marlin_scale_perm[num_bits],
        scale_perm_single=marlin_scale_perm_single[num_bits])
    layer.weight_scale = torch.nn.Parameter(marlin_scales,
                                            requires_grad=False)

    # Allocate marlin workspace
    max_workspace_size = (part_size_n //
                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=device,
                            requires_grad=False)
    layer.workspace = workspace
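
pack_fp8_to_int32 is referenced above but not shown in this hunk. A plausible sketch of what it does, assuming four fp8_e4m3 bytes are packed little-endian into each int32 along the K dimension, so the GPTQ repack kernel sees them as four 8-bit values (hence num_bits=8 above); the name is suffixed to mark it as a sketch:

def pack_fp8_to_int32_sketch(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """Pack groups of 4 fp8 values along dim 0 into single int32 values."""
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Group every 4 consecutive rows, then reinterpret the fp8 bytes as uint8.
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
    byte_tensor = reshaped.view(torch.uint8)

    # Pack the 4 bytes little-endian into one int32; result has shape
    # (size_k // 4, size_n).
    packed = (byte_tensor[:, 0].to(torch.int32)
              | (byte_tensor[:, 1].to(torch.int32) << 8)
              | (byte_tensor[:, 2].to(torch.int32) << 16)
              | (byte_tensor[:, 3].to(torch.int32) << 24))
    return packed.contiguous()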

def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
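
The rest of marlin_permute_weights is truncated in this view. For orientation, a hedged sketch (assumed, not the verbatim body) of the tiling transform the asserts above set up; the kernel-specific `perm` is then applied within this tiled layout:

# Break the (size_k, size_n) matrix into tile x tile blocks and lay each
# block out contiguously, producing a (size_k // tile, size_n * tile)
# tensor that `perm` then reorders into the Marlin fragment layout.
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))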