[ Misc ] Support Fp8 via llm-compressor (#6110)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.format_24 import (
|
||||
mask_creator, sparse_semi_structured_from_dense_cutlass)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
|
||||
@@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_pack_factor, quantize_weights, sort_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
# Default tile edge used by marlin_permute_weights; that helper asserts that
# both size_k and size_n are multiples of the tile.
MARLIN_TILE = 16
GPTQ_MARLIN_TILE = 16

# Together these size the Marlin workspace buffer:
# (size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL int slots.
GPTQ_MARLIN_MIN_THREAD_N = 64
# NOTE(review): not referenced in the visible code; presumably the matching
# minimum along the K dimension - confirm against the kernel.
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

# NOTE(review): presumably the quantization configurations accepted by the
# GPTQ Marlin kernels (weight bit-widths, group sizes with -1 standing for
# "no grouping", and symmetric-only quantization) - confirm with the users
# of these lists.
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
|
||||
|
||||
|
||||
def is_marlin_supported():
|
||||
@@ -22,7 +32,92 @@ def is_marlin_supported():
|
||||
return capability[0] >= 8
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
|
||||
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run a linear layer through the FP8 Marlin GEMM kernel.

    For GPUs that lack FP8 hardware support, the Marlin kernel can be
    leveraged for fast weight-only FP8 quantization.

    Args:
        input: Activations; all leading dimensions are flattened into a
            single row dimension before the GEMM.
        weight: Quantized weight, forwarded to the kernel as ``b_q_weight``
            (presumably already repacked by
            ``prepare_fp8_layer_for_marlin`` - confirm at the call site).
        weight_scale: Weight scales, forwarded as ``b_scales``.
        workspace: Scratch buffer forwarded to the kernel.
        size_n: Size of the last dimension of the returned tensor.
        size_k: Forwarded to the kernel as ``size_k`` (presumably equal to
            ``input.shape[-1]`` - confirm).
        bias: Optional bias, added in place to the GEMM result.

    Returns:
        The GEMM result reshaped to ``input.shape[:-1] + (size_n, )``.
    """
    # Collapse every leading dimension so the kernel sees a 2-D matrix.
    flat_input = input.reshape(-1, input.shape[-1])
    num_rows = flat_input.shape[0]

    result = ops.fp8_marlin_gemm(
        a=flat_input,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=num_rows,
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        # In-place add avoids allocating a second output-sized tensor.
        result.add_(bias)

    # Restore the caller's leading dimensions with size_n as the last one.
    final_shape = input.shape[:-1] + (size_n, )
    return result.reshape(final_shape)
|
||||
|
||||
|
||||
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert an FP8 layer in place so it can be run by the Marlin kernel.

    Emits a one-time warning, then replaces ``layer.weight`` and
    ``layer.weight_scale`` with Marlin-format parameters and attaches a
    zero-initialised ``layer.workspace`` buffer.

    Args:
        layer: Module exposing ``weight``, ``weight_scale``, ``orig_dtype``,
            ``output_size_per_partition`` and ``input_size_per_partition``.
    """
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    # Single source of truth for the bit-width used by both the weight
    # repack and the scale-permutation table lookups below.
    num_bits = 8

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    # Captured before layer.weight is replaced further down.
    device = layer.weight.device

    # WEIGHTS
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(layer.weight)

    # Repack weights to marlin format; an empty perm tensor is passed.
    no_perm = torch.empty(0, dtype=torch.int, device=device)
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=no_perm,
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=num_bits,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    channelwise_scales = layer.weight_scale.repeat(1, part_size_n).to(
        layer.orig_dtype).to(device)

    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=channelwise_scales,
        size_k=part_size_k,
        size_n=part_size_n,
        group_size=-1,
        scale_perm=marlin_scale_perm[num_bits],
        scale_perm_single=marlin_scale_perm_single[num_bits])
    layer.weight_scale = torch.nn.Parameter(marlin_scales,
                                            requires_grad=False)

    # Allocate marlin workspace
    max_workspace_size = (part_size_n //
                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    layer.workspace = torch.zeros(max_workspace_size,
                                  dtype=torch.int,
                                  device=device,
                                  requires_grad=False)
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
|
||||
assert q_w.shape == (size_k, size_n)
|
||||
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
|
||||
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
|
||||
|
||||
Reference in New Issue
Block a user