[ Misc ] fp8-marlin channelwise via compressed-tensors (#6524)
Co-authored-by: mgoin <michael@neuralmagic.com>
@@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
-    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
+    all_close_1d, apply_fp8_linear, convert_to_channelwise,
+    create_per_tensor_scale_param, cutlass_fp8_supported,
+    per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import print_warning_once
@@ -179,19 +180,29 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.input_scale = None
 
-        # If checkpoint is fp8, requantize the separately quantized logical
-        # weights into a single fp8 weight with a single weight scale.
+        # If checkpoint is fp8, handle that there are N scales for N
+        # shards in a fused module
         else:
-            # Dequant -> Quant with max scale.
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                weight = layer.weight
+                weight_scale = convert_to_channelwise(layer.weight_scale,
+                                                      layer.logical_widths)
+
+            # If using w8a8, torch._scaled_mm needs per tensor, so
+            # requantize the logical shards as a single weight.
+            else:
+                # Dequant -> Quant with max scale so we can run per tensor.
+                weight_scale, weight = requantize_with_max_scale(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
+                    logical_widths=layer.logical_widths,
+                )
 
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
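
The marlin (w8a16) branch is easiest to see with a minimal sketch of the conversion its call site implies: each logical shard of a fused module ships one per-tensor scale, and a channelwise kernel wants one scale per output channel. The helper name convert_to_channelwise_sketch, the returned 1-D shape, and the QKV widths below are illustrative assumptions, not the vLLM source.

import torch

def convert_to_channelwise_sketch(weight_scale: torch.Tensor,
                                  logical_widths: list[int]) -> torch.Tensor:
    # One per-tensor scale per fused shard comes in; repeat each scale
    # across that shard's output channels so the result is channelwise.
    channel_scales = [
        scale.repeat(width)
        for scale, width in zip(weight_scale, logical_widths)
    ]
    return torch.cat(channel_scales)

# Example: a fused QKV projection with shard widths 4096/1024/1024
# and one (made-up) scale per shard.
scales = torch.tensor([0.020, 0.015, 0.030])
per_channel = convert_to_channelwise_sketch(scales, [4096, 1024, 1024])
assert per_channel.shape == (4096 + 1024 + 1024,)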
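The w8a8 branch goes the other way: torch._scaled_mm takes a single per-tensor scale, so the N shard scales must collapse into one. Below is a hedged sketch of the dequant-then-requant-at-max-scale idea behind requantize_with_max_scale; it is not the vLLM implementation, and the clamp bound 448.0 (the float8_e4m3fn max) plus the in-place fp8 slice assignment are assumptions of this sketch.

import torch

def requantize_with_max_scale_sketch(weight: torch.Tensor,
                                     weight_scale: torch.Tensor,
                                     logical_widths: list[int]):
    # Use the largest shard scale as the single shared scale.
    max_scale = weight_scale.max()
    start = 0
    for scale, width in zip(weight_scale, logical_widths):
        end = start + width
        # Dequantize the shard with its own scale, then requantize
        # with the shared max scale.
        shard = weight[start:end].to(torch.float32) * scale
        requant = (shard / max_scale).clamp(min=-448.0, max=448.0)
        weight[start:end] = requant.to(torch.float8_e4m3fn)
        start = end
    # Matches the call site's "weight_scale, weight = ..." ordering.
    return max_scale, weight

The design upshot the diff's comments point at: the marlin path keeps each shard's own scale resolution by widening scales to channelwise, while the w8a8 path must flatten everything to one scale because of the torch._scaled_mm contract.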