[Misc][Breaking] Change FP8 checkpoint format from act_scale -> input_scale (#5353)
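This is a breaking change for FP8 checkpoints serialized with the old per-layer activation-scale name: tensors previously stored as act_scale are now expected under input_scale. Below is a minimal, hypothetical migration sketch, assuming the old checkpoint is a plain PyTorch state dict whose static activation scales live under keys ending in ".act_scale"; the function name, file paths, and key layout are illustrative and not part of this commit.

# Hypothetical one-off migration helper (illustrative; not part of this commit).
# Assumes the old FP8 checkpoint is a plain PyTorch state dict whose static
# activation scales are stored under keys ending in ".act_scale".
import torch

def rename_act_scale_keys(src_path: str, dst_path: str) -> None:
    state_dict = torch.load(src_path, map_location="cpu")

    def rename(key: str) -> str:
        # Rename only the trailing ".act_scale" component; leave other keys alone.
        if key.endswith(".act_scale"):
            return key[:-len("act_scale")] + "input_scale"
        return key

    torch.save({rename(k): v for k, v in state_dict.items()}, dst_path)

# Usage (paths are placeholders):
# rename_act_scale_keys("model_fp8_act_scale.pt", "model_fp8_input_scale.pt")

The diff below applies the same rename throughout Fp8LinearMethod, so checkpoint producers are expected to emit input_scale going forward.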
@@ -171,10 +171,10 @@ class Fp8LinearMethod(LinearMethodBase):
                 output_partition_sizes=output_partition_sizes,
                 **extra_weight_attrs)
 
-            # ACTIVATION SCALE
+            # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
                 self._create_scale_param(
-                    scale_name="act_scale",
+                    scale_name="input_scale",
                     layer=layer,
                     output_partition_sizes=output_partition_sizes,
                     **extra_weight_attrs)
@@ -207,7 +207,7 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
-            layer.act_scale = None
+            layer.input_scale = None
             return
 
         # If checkpoint is fp8, requantize the separately quantized logical
@@ -232,18 +232,18 @@ class Fp8LinearMethod(LinearMethodBase):
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
-        # ACT_SCALE
+        # INPUT ACTIVATION SCALE
         # Dynamic: set to None (required input to ops.scaled_fp8_quant).
-        # Static: set to max of the act_scales (since they are equal).
+        # Static: set to max of the input_scales (since they are equal).
         if self.quant_config.activation_scheme == "dynamic":
-            layer.act_scale = None
+            layer.input_scale = None
         elif self.quant_config.activation_scheme == "static":
-            if not all_close_1d(layer.act_scale):
+            if not all_close_1d(layer.input_scale):
                 raise ValueError(
-                    "All the act_scales for the logical weights of a layer "
-                    f"must be equal. But got {layer.act_scale}")
-            layer.act_scale = Parameter(layer.act_scale.max(),
-                                        requires_grad=False)
+                    "All the input_scales for the logical weights of a "
+                    f"layer must be equal. But got {layer.input_scale}")
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
         else:
             raise ValueError(
                 f"Unknown scheme {self.quant_config.activation_scheme}")
@@ -254,11 +254,11 @@ class Fp8LinearMethod(LinearMethodBase):
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.act_scale is None and x_scale computed from x.
-        # If static, layer.act_scale is scalar and x_scale set to act_scale.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
 
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
             output = ops.cutlass_scaled_mm_dq(
@@ -271,7 +271,7 @@ class Fp8LinearMethod(LinearMethodBase):
 
         else:
             qinput, x_scale = ops.scaled_fp8_quant(x,
-                                                   layer.act_scale,
+                                                   layer.input_scale,
                                                    batch_dim_padding=17)
 
             # Fused GEMM_DQ -- note we padded the input above because
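For context on the static-scheme handling in the @@ -232,18 +232,18 @@ hunk: a fused linear layer (for example, merged QKV or gate/up projections) applies all of its logical weight partitions to the same input activation, so a statically calibrated checkpoint carries one input scale per partition and those scales must agree; process_weights_after_loading then collapses the vector to a single scalar. The sketch below reproduces that check-and-collapse step in isolation with plain PyTorch; collapse_input_scale is a made-up name for this illustration, not vLLM's API.

# Standalone illustration of the static input-scale handling shown above.
# all_close_1d mirrors the helper referenced in the diff; collapse_input_scale
# is an illustrative name for this sketch only.
import torch

def all_close_1d(x: torch.Tensor) -> bool:
    # True when every element of a 1-D tensor matches the first element.
    assert x.dim() == 1
    return bool(torch.allclose(x, x[0].expand_as(x)))

def collapse_input_scale(input_scale: torch.Tensor) -> torch.Tensor:
    # Logical partitions of a fused layer see the same activation, so their
    # serialized scales must be equal; keep a single scalar (the max).
    if not all_close_1d(input_scale):
        raise ValueError(
            "All the input_scales for the logical weights of a "
            f"layer must be equal. But got {input_scale}")
    return input_scale.max()

# e.g. a merged QKV projection serialized with one scale per logical weight:
print(collapse_input_scale(torch.tensor([0.023, 0.023, 0.023])))  # tensor(0.0230)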