[Misc] update fp8 to use vLLMParameter (#7437)
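The change replaces the raw torch.nn.Parameter plus set_weight_attrs pattern in Fp8LinearMethod.create_weights with vLLM's ModelWeightParameter and PerTensorScaleParameter, which carry their sharding metadata (input_dim, output_dim) and the weight_loader directly. Roughly, that metadata tells the loader which slice of the full checkpoint tensor belongs to the current tensor-parallel rank. The sketch below illustrates only that narrowing step, as a standalone toy under assumptions; load_output_shard, tp_rank, and full_weight are made-up names, not vLLM's loader API.

import torch


def load_output_shard(param: torch.Tensor, full_weight: torch.Tensor,
                      output_dim: int, tp_rank: int, tp_size: int) -> None:
    """Copy this rank's slice of a full checkpoint tensor into a sharded param.

    Toy illustration of what the output_dim metadata enables: the parameter
    owns 1/tp_size of the output dimension, so we narrow the checkpoint
    tensor along that dimension before copying.
    """
    shard_size = full_weight.shape[output_dim] // tp_size
    shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
    param.data.copy_(shard)


# Example: an 8x4 checkpoint weight split across 2 ranks along dim 0 (output).
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
local = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
load_output_shard(local, full, output_dim=0, tp_rank=1, tp_size=2)
print(local.data)  # rows 4..7 of the full weight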
@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip, print_warning_once
@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
 
@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
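The scale handling above follows the same pattern: a fused module with N shards registers N per-tensor scales, each slot is pre-filled with torch.finfo(torch.float32).min as an "unloaded" sentinel, and the weight_loader fills one slot per checkpoint shard. Below is a minimal standalone sketch of that idea under assumptions; MiniScaleParam and load_shard are hypothetical names, not vLLM's PerTensorScaleParameter API.

import torch


# Hypothetical stand-in for the idea behind PerTensorScaleParameter:
# one fp32 scale slot per shard of a fused module (e.g. q/k/v merged
# into a single linear layer), pre-filled with a sentinel value.
class MiniScaleParam(torch.nn.Parameter):

    def __new__(cls, num_shards: int):
        data = torch.full((num_shards, ), torch.finfo(torch.float32).min)
        return super().__new__(cls, data=data, requires_grad=False)

    def load_shard(self, shard_id: int, loaded_scale: torch.Tensor) -> None:
        # Each checkpoint shard contributes exactly one per-tensor scale.
        self.data[shard_id] = loaded_scale


# A module fusing two shards gets two scale slots, filled independently
# as the corresponding checkpoint tensors are loaded.
scale = MiniScaleParam(num_shards=2)
scale.load_shard(0, torch.tensor(0.021))
scale.load_shard(1, torch.tensor(0.017))
print(scale.data)  # tensor([0.0210, 0.0170])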