[Bugfix] Fix bugs of running Quark quantized models (#16236)
Signed-off-by: chaow <chaow@amd.com>
This commit is contained in:
@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
input_scale=layer.input_scale)
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale)
|
||||
input_scale=input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
# the newly added parameters
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
self.logical_widths = output_partition_sizes
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
# WEIGHT SCALE
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
ChannelQuantZPParameter = ChannelQuantScaleParameter
|
||||
weight_zero_point = ChannelQuantZPParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.int8),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
PerTensorZPParameter = PerTensorScaleParameter
|
||||
weight_zero_point = PerTensorZPParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: quark stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
input_zero_point = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter("weight_zero_point", None)
|
||||
delattr(layer, 'weight_zero_point')
|
||||
if self.input_symmetric:
|
||||
layer.register_parameter("input_zero_point", None)
|
||||
delattr(layer, 'input_zero_point')
|
||||
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user