[ Misc ] Support Fp8 via llm-compressor (#6110)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
vllm/model_executor/layers/quantization/fp8.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
@@ -11,11 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
-    marlin_permute_scales)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    pack_fp8_to_int32)
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
+    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import print_warning_once
@@ -25,13 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
-def cutlass_fp8_supported() -> bool:
-    capability = current_platform.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-
-    return ops.cutlass_scaled_mm_supports_fp8(capability)
-
-
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
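The deleted local helper is not lost: `cutlass_fp8_supported` is now imported from `w8a8_utils` (see the import hunk above). For reference, a minimal sketch of the check, assuming the shared helper keeps the removed body; the `ops` import path is assumed to match the rest of this file:

```python
from vllm import _custom_ops as ops  # import path assumed
from vllm.platforms import current_platform


def cutlass_fp8_supported() -> bool:
    # Encode the (major, minor) compute capability as a two-digit int,
    # e.g. (8, 9) -> 89 for Ada, (9, 0) -> 90 for Hopper.
    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    # The compiled CUTLASS kernels give the authoritative answer.
    return ops.cutlass_scaled_mm_supports_fp8(capability)
```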
@@ -117,23 +110,6 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89
 
-    def _create_scale_param(
-        self,
-        scale_name: str,
-        layer: torch.nn.Module,
-        output_partition_sizes: List[int],
-        **extra_weight_attrs,
-    ) -> None:
-        scale = Parameter(torch.empty(len(output_partition_sizes),
-                                      dtype=torch.float32),
-                          requires_grad=False)
-        scale[:] = torch.finfo(torch.float8_e4m3fn).min
-        layer.register_parameter(scale_name, scale)
-        set_weight_attrs(scale, {
-            **extra_weight_attrs,
-            "needs_scalar_to_array": True,
-        })
-
     def create_weights(
         self,
         layer: torch.nn.Module,
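`_create_scale_param` moves behind `create_per_tensor_scale_param` in `w8a8_utils`. A sketch of what that helper presumably does, mirroring the removed method except that it returns the parameter and lets the call site register it (the `layer.register_parameter` calls appear in a later hunk):

```python
from typing import List

import torch
from torch.nn import Parameter

from vllm.model_executor.utils import set_weight_attrs


def create_per_tensor_scale_param(output_partition_sizes: List[int],
                                  **extra_weight_attrs) -> Parameter:
    # One fp32 scale per logical shard of the fused weight.
    scale = Parameter(torch.empty(len(output_partition_sizes),
                                  dtype=torch.float32),
                      requires_grad=False)
    # Sentinel init: a scale still at fp8-min after loading means the
    # checkpoint stored the fused module with a single scale.
    scale[:] = torch.finfo(torch.float8_e4m3fn).min
    set_weight_attrs(scale, {
        **extra_weight_attrs,
        # Tell the weight loader to scatter scalar scales across shards.
        "needs_scalar_to_array": True,
    })
    return scale
```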
@@ -147,7 +123,6 @@ class Fp8LinearMethod(LinearMethodBase):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
 
-        layer.process_after_load = True
         layer.logical_widths = output_partition_sizes
 
         layer.input_size_per_partition = input_size_per_partition
@@ -173,144 +148,50 @@ class Fp8LinearMethod(LinearMethodBase):
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            self._create_scale_param(
-                scale_name="weight_scale",
-                layer=layer,
-                output_partition_sizes=output_partition_sizes,
-                **extra_weight_attrs)
+            scale = create_per_tensor_scale_param(output_partition_sizes,
+                                                  **extra_weight_attrs)
+            layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                self._create_scale_param(
-                    scale_name="input_scale",
-                    layer=layer,
-                    output_partition_sizes=output_partition_sizes,
-                    **extra_weight_attrs)
-
-        # For GPUs without FP8 hardware support, we use Marlin for fast
-        # fused dequantization
-        if self.use_marlin:
-            layer.marlin_state = GPTQMarlinState.REPACK
-
-    def prepare_layer_for_marlin(self, layer: Module) -> None:
-        print_warning_once(
-            "Your GPU does not have native support for FP8 computation but "
-            "FP8 quantization is being used. Weight-only FP8 compression will "
-            "be used leveraging the Marlin kernel. This may degrade "
-            "performance for compute-heavy workloads.")
-
-        part_size_n = layer.output_size_per_partition
-        part_size_k = layer.input_size_per_partition
-
-        assert layer.marlin_state == GPTQMarlinState.REPACK
-        layer.marlin_state = GPTQMarlinState.READY
-
-        device = layer.weight.device
-
-        # WEIGHTS
-        # Repack weights to gptq format (packed int32 elements)
-        packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
-
-        # Repack weights to marlin format
-        marlin_qweight = ops.gptq_marlin_repack(
-            b_q_weight=packed_gptq_qweight,
-            perm=torch.empty(0, dtype=torch.int, device=device),
-            size_k=part_size_k,
-            size_n=part_size_n,
-            num_bits=8,
-        )
-        layer.weight = Parameter(marlin_qweight, requires_grad=False)
-
-        # WEIGHT SCALES
-        # Currently Marlin doesn't support per-tensor scales, so we
-        # expand it to channelwise
-        scales = layer.weight_scale.repeat(1, part_size_n).to(
-            layer.orig_dtype).to(device)
-        # Permute scales
-        marlin_scales = marlin_permute_scales(
-            s=scales,
-            size_k=part_size_k,
-            size_n=part_size_n,
-            group_size=-1,
-            num_bits=8,
-        )
-        layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
-
-        # Allocate marlin workspace
-        max_workspace_size = (
-            part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
-        workspace = torch.zeros(max_workspace_size,
-                                dtype=torch.int,
-                                device=device,
-                                requires_grad=False)
-
-        layer.workspace = workspace
+                scale = create_per_tensor_scale_param(output_partition_sizes,
+                                                      **extra_weight_attrs)
+                layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if (not hasattr(layer, "process_after_load")
-                or not layer.process_after_load):
-            return
-
-        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
+        # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                          scale=None)
 
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            layer.logical_widths = None
             layer.input_scale = None
-            if self.use_marlin:
-                self.prepare_layer_for_marlin(layer)
-            return
 
         # If checkpoint is fp8, requantize the separately quantized logical
         # weights into a single fp8 weight with a single weight scale.
         else:
             # WEIGHT_SCALE / WEIGHT
             # Loop over logical weights, requantizing with single scale.
-            max_w_scale = layer.weight_scale.max()
-
-            # QKV / MLP is fused in the on disk checkpoint if any of the
-            # weight scales are still set to the default since we initialize
-            # N weight scales for N shards but we only load 1 weight scale
-            # from disk in this case. As a result, we skip dequant -> requant
-            # since we already have quantized QKV together.
-            # Sample Model with fused checkpoint:
-            # * nm-testing/Phi-3-mini-128k-instruct-FP8
-            unfused_module_in_checkpoint = (
-                layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
-
-            if unfused_module_in_checkpoint:
-                start = 0
-                for idx, logical_width in enumerate(layer.logical_widths):
-                    end = start + logical_width
-                    weight_dq = per_tensor_dequantize(
-                        layer.weight[start:end, :], layer.weight_scale[idx])
-
-                    layer.weight[start:end, :] = per_tensor_quantize(
-                        weight_dq, layer.weight_scale.max())
-                    start = end
-            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+            # Dequant -> Quant with max scale.
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             # WEIGHT
             # Transpose weight for passing to torch._scaled_mm
-            weight = layer.weight
+            # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # INPUT ACTIVATION SCALE
             # Dynamic: set to None (required input to ops.scaled_fp8_quant).
             # Static: set to max of the input_scales (since they are equal).
-            if self.quant_config.activation_scheme == "dynamic":
-                layer.input_scale = None
-            elif self.quant_config.activation_scheme == "static":
+            if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
-                raise ValueError(
-                    f"Unknown scheme {self.quant_config.activation_scheme}")
+                layer.input_scale = None
 
-        if self.use_marlin:
-            self.prepare_layer_for_marlin(layer)
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
     def apply(self,
               layer: torch.nn.Module,
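The inline dequant-requant loop deleted above is folded into `requantize_with_max_scale`. A self-contained sketch of the equivalent logic, assuming the `w8a8_utils` helper bundles the removed code; the clamping mirrors `per_tensor_quantize` at the bottom of this diff:

```python
from typing import List, Tuple

import torch


def requantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # The max scale across all logical shards becomes the shared scale.
    max_w_scale = weight_scale.max()
    finfo = torch.finfo(torch.float8_e4m3fn)

    # If the last shard scale still sits at the init sentinel, the
    # checkpoint stored the fused module with one scale already, so the
    # dequant -> requant pass can be skipped.
    unfused_module_in_checkpoint = weight_scale[-1] > finfo.min
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            # Dequantize the shard with its own scale...
            weight_dq = weight[start:end, :].to(torch.float16) * weight_scale[idx]
            # ...then requantize it onto the shared max scale.
            weight[start:end, :] = (weight_dq / max_w_scale).clamp(
                min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
            start = end
    return max_w_scale, weight
```

The pair is returned as `(max_w_scale, weight)`, the order the new call site unpacks.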
@@ -318,65 +199,22 @@ class Fp8LinearMethod(LinearMethodBase):
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         if self.use_marlin:
-            # For GPUs that lack FP8 hardware support, we can leverage the
-            # Marlin kernel for fast weight-only FP8 quantization
-
-            reshaped_x = x.reshape(-1, x.shape[-1])
-            out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
-
-            output = ops.fp8_marlin_gemm(
-                a=reshaped_x,
-                b_q_weight=layer.weight,
-                b_scales=layer.weight_scale,
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
                 workspace=layer.workspace,
-                num_bits=8,
-                size_m=reshaped_x.shape[0],
                 size_n=layer.output_size_per_partition,
                 size_k=layer.input_size_per_partition,
-            )
-
-            if bias is not None:
-                output.add_(bias)  # In-place add
-
-            return output.reshape(out_shape)
-
-        else:
-
-            # ops.scaled_fp8_quant supports both dynamic and static quant.
-            # If dynamic, layer.input_scale is None and x_scale computed from x
-            # If static, layer.input_scale is scalar and x_scale is input_scale
-
-            if bias is None and self.cutlass_fp8_supported:
-                qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
-
-                # Fused GEMM_DQ
-                output = ops.cutlass_scaled_mm(
-                    qinput,
-                    layer.weight,
-                    out_dtype=x.dtype,
-                    scale_a=x_scale,
-                    scale_b=layer.weight_scale,
-                )
-
-            else:
-                qinput, x_scale = ops.scaled_fp8_quant(x,
-                                                       layer.input_scale,
-                                                       batch_dim_padding=17)
-
-                # Fused GEMM_DQ -- note we padded the input above because
-                # torch._scaled_mm is more performant for matrices with
-                # batch dimension > 16. Note that this could change
-                # in the future.
-                output, _ = torch._scaled_mm(
-                    qinput,
-                    layer.weight,
-                    out_dtype=x.dtype,
-                    scale_a=x_scale,
-                    scale_b=layer.weight_scale,
-                    bias=bias,
-                )
-
-            return torch.narrow(output, 0, 0, x.shape[0])
+                bias=bias)
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported)
 
 
 class Fp8MoEMethod(FusedMoEMethodBase):
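`apply_fp8_linear` now hides the dispatch deleted in the hunk above. A sketch of the consolidated helper, assuming it keeps both removed branches: CUTLASS when supported and bias-free, otherwise `torch._scaled_mm` with the batch padded past 16 rows (per the removed comment); the `ops` import path is assumed:

```python
from typing import Optional

import torch

from vllm import _custom_ops as ops  # import path assumed


def apply_fp8_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
        cutlass_fp8_supported: bool = True) -> torch.Tensor:
    if bias is None and cutlass_fp8_supported:
        # Dynamic (input_scale=None) or static activation quantization.
        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
        # Fused GEMM + dequant via CUTLASS.
        return ops.cutlass_scaled_mm(qinput,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale)

    # Pad the batch past 16 rows; torch._scaled_mm is faster there
    # (see the comment removed in this hunk).
    qinput, x_scale = ops.scaled_fp8_quant(input,
                                           input_scale,
                                           batch_dim_padding=17)
    output, _ = torch._scaled_mm(qinput,
                                 weight,
                                 out_dtype=input.dtype,
                                 scale_a=x_scale,
                                 scale_b=weight_scale,
                                 bias=bias)
    # Strip the padding rows before returning.
    return torch.narrow(output, 0, 0, input.shape[0])
```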
@@ -399,8 +237,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                        intermediate_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
 
-        layer.process_after_load = True
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
@@ -465,9 +301,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.a2_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if (not hasattr(layer, "process_after_load")
-                or not layer.process_after_load):
-            return
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -531,7 +364,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                     shard_size, :],
                                     layer.w13_scale[expert_id][shard_id])
                                 layer.w13_weight[expert_id][
-                                    start:start + shard_size, :] = per_tensor_quantize(
+                                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                                     dq_weight, max_w13_scales[expert_id])
                             start += shard_size
 
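With a scale supplied, `ops.scaled_fp8_quant` quantizes statically and returns the scale alongside the fp8 tensor, which is why the second return value can be discarded above. A small usage sketch (shapes and values illustrative):

```python
import torch

from vllm import _custom_ops as ops  # import path assumed

# A dequantized shard of an expert weight and the shared max scale.
dq_weight = torch.randn(512, 512, dtype=torch.float16, device="cuda")
max_scale = torch.tensor(0.02, dtype=torch.float32, device="cuda")

# Static quantization: fp8 tensor out, scale passed through unchanged.
qweight, _ = ops.scaled_fp8_quant(dq_weight, max_scale)
assert qweight.dtype == torch.float8_e4m3fn
```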
@@ -596,23 +429,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
                 "cause accuracy issues. Please make sure kv-cache scaling "
                 "factor is available in the fp8 checkpoint.")
             del layer.kv_scale
-
-
-def per_tensor_quantize(tensor: torch.Tensor,
-                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
-    return qweight.to(torch.float8_e4m3fn)
-
-
-def per_tensor_dequantize(
-        tensor: torch.Tensor, inv_scale: Union[float,
-                                               torch.Tensor]) -> torch.Tensor:
-    fake_qweight = tensor.to(torch.float16)
-    dq_weight = fake_qweight * inv_scale
-    return dq_weight
-
-
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
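These three helpers move to `w8a8_utils` (see the import hunk at the top of the diff). A quick round trip using the definitions shown above illustrates the intended contract: quantize divides by the scale and clamps into the e4m3 range, dequantize multiplies back:

```python
import torch

# Pick a per-tensor scale that maps the max |w| onto the fp8 max.
w = torch.randn(8, 8, dtype=torch.float16)
scale = w.abs().max().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max

q = per_tensor_quantize(w, scale)           # fp8 weights
w_approx = per_tensor_dequantize(q, scale)  # higher-precision reconstruction

# The error is bounded by the fp8 quantization step.
print((w.float() - w_approx.float()).abs().max())
```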