[ Misc ] fp8-marlin channelwise via compressed-tensors (#6524)

Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Robert Shaw
2024-07-25 09:46:04 -07:00
committed by GitHub
parent b75e314fff
commit 889da130e7
13 changed files with 219 additions and 49 deletions

View File

@@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
@@ -179,19 +180,29 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
# Dequant -> Quant with max scale.
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)