[Bugfix] Fix bugs when running Quark quantized models (#16236)

Signed-off-by: chaow <chaow@amd.com>
chaow-amd authored 2025-04-11 22:18:32 +08:00, committed by GitHub
parent e9528f6dc6
commit 9e90c9f73f
3 changed files with 67 additions and 22 deletions

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
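
What the flag flip above changes: when no static input scale is available, Fp8LinearOp quantizes activations on the fly, and use_per_token_if_dynamic selects whether it derives one scale per token or a single scale for the whole tensor. A rough sketch of the two strategies in plain torch (illustration only, not vLLM's Fp8LinearOp internals):

    import torch

    finfo = torch.finfo(torch.float8_e4m3fn)
    x = torch.randn(4, 8)  # [num_tokens, hidden_size]

    # use_per_token_if_dynamic=True: one scale per token (row).
    per_token_scale = x.abs().amax(dim=-1, keepdim=True) / finfo.max

    # use_per_token_if_dynamic=False: a single scale for the whole tensor.
    per_tensor_scale = x.abs().amax() / finfo.max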
@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
             )
 
         if current_platform.is_fp8_fnuz():
+            input_scale = getattr(layer, 'input_scale', None)
             weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=weight,
                 weight_scale=max_w_scale,
-                input_scale=layer.input_scale)
+                input_scale=input_scale)
             if input_scale is not None:
                 layer.input_scale = Parameter(input_scale,
                                               requires_grad=False)
@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
             weight = layer.weight
 
         if current_platform.is_fp8_fnuz():
+            input_scale = getattr(layer, 'input_scale', None)
             weight, weight_scale, input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=layer.weight_scale,
-                    input_scale=layer.input_scale)
+                    input_scale=input_scale)
             if input_scale is not None:
                 layer.input_scale = Parameter(input_scale,
                                               requires_grad=False)
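
Both fnuz hunks guard the same failure: a dynamically quantized Quark checkpoint carries no input_scale, so the old layer.input_scale access raised AttributeError during the e4m3fn-to-e4m3fnuz conversion on ROCm. A minimal sketch of the guard (toy stand-in layer; attribute names as in the diff):

    import torch

    layer = torch.nn.Linear(8, 8)  # stand-in layer with no 'input_scale'

    # Old behavior: direct attribute access fails for dynamic quantization.
    try:
        input_scale = layer.input_scale
    except AttributeError:
        input_scale = None

    # Fixed behavior: default to None so the conversion helper can skip
    # input-scale handling when the checkpoint provides none.
    input_scale = getattr(layer, 'input_scale', None)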
@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         # the newly added parameters
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
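
The scale buffer above is created 1-D because Quark serializes per-channel scales with shape (N,) rather than (N, 1), and loading a (N,) checkpoint tensor into a (N, 1) parameter fails the shape check. A toy reproduction (checkpoint layout assumed from the fix):

    import torch

    out_features = 16
    loaded_scale = torch.rand(out_features)    # checkpoint tensor: (N,)

    scale_2d = torch.empty((out_features, 1))  # old buffer shape: (N, 1)
    scale_1d = torch.empty(out_features)       # new buffer shape: (N,)

    scale_1d.copy_(loaded_scale)               # matches the checkpoint
    # scale_2d.copy_(loaded_scale) raises a size-mismatch error, the same
    # class of failure hit when loading Quark per-channel scales.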

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py

@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        self.logical_widths = output_partition_sizes
+        layer.logical_widths = output_partition_sizes
 
         scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
             is_channelwise=(self.qscheme == "per_channel"),
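
Moving logical_widths onto the layer matters because one scheme instance is shared by every linear layer it quantizes: writing to self means the last create_weights call silently clobbers the widths recorded for earlier layers. A toy sketch of the aliasing (class names hypothetical):

    class Scheme:
        def create_weights(self, layer, output_partition_sizes):
            # Fixed: per-layer state lives on the layer itself.
            layer.logical_widths = output_partition_sizes
            # Buggy variant: self.logical_widths = output_partition_sizes
            # would be overwritten by whichever layer registers last.

    class Layer:
        pass

    scheme, qkv, mlp = Scheme(), Layer(), Layer()
    scheme.create_weights(qkv, [128, 128, 128])
    scheme.create_weights(mlp, [512])
    assert qkv.logical_widths == [128, 128, 128]
    assert mlp.logical_widths == [512]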
@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
         # WEIGHT SCALE
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
+            ChannelQuantZPParameter = ChannelQuantScaleParameter
+            weight_zero_point = ChannelQuantZPParameter(
+                data=torch.empty((sum(output_partition_sizes)),
+                                 dtype=torch.int8),
+                output_dim=0,
+                weight_loader=weight_loader)
         else:
             assert self.qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
+            PerTensorZPParameter = PerTensorScaleParameter
+            weight_zero_point = PerTensorZPParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.int8),
+                weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
+        layer.register_parameter("weight_zero_point", weight_zero_point)
 
         # INPUT SCALE
         if self.is_static_input_scheme:
@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
                                             weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
             if not self.input_symmetric:
-                # Note: quark stores the zp using the same dtype
-                # as the weights
-                # AZP loaded as int8 but used as int32
-                input_zero_point = BasevLLMParameter(
-                    data=torch.empty(1, dtype=torch.int8),
-                    weight_loader=weight_loader)
-                layer.register_parameter("input_zero_point", input_zero_point)
-
+                input_zero_point = BasevLLMParameter(data=torch.empty(
+                    1, dtype=torch.int8),
+                                                     weight_loader=weight_loader)
+                layer.register_parameter("input_zero_point", input_zero_point)
 
         self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                                   w_q_param_name="weight",
@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
     # Checkpoints are serialized in quark format, which is
     # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.register_parameter("weight_zero_point", None)
+        delattr(layer, 'weight_zero_point')
+        if self.input_symmetric:
+            layer.register_parameter("input_zero_point", None)
+            delattr(layer, 'input_zero_point')
+
         self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
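
Taken together, the int8 changes register zero-point parameters only so that Quark checkpoints, which serialize zero-point tensors even for symmetric quantization, load without unexpected-parameter errors; process_weights_after_loading then frees and removes the ones the ScaledMM kernels never read (the weight zero points, and the input zero point when quantization is symmetric). A self-contained sketch of that register-then-delete pattern (toy module, not vLLM's classes):

    import torch
    from torch import nn

    class ToyQuantLinear(nn.Module):
        def __init__(self, out_features: int):
            super().__init__()
            # Placeholder so checkpoint loading finds a destination tensor.
            self.register_parameter(
                "weight_zero_point",
                nn.Parameter(torch.empty(out_features, dtype=torch.int8),
                             requires_grad=False))

        def process_weights_after_loading(self):
            # The kernel ignores symmetric zero points: drop the data,
            # then remove the attribute so nothing stale is left behind.
            self.register_parameter("weight_zero_point", None)
            delattr(self, "weight_zero_point")

    layer = ToyQuantLinear(8)
    layer.process_weights_after_loading()
    assert not hasattr(layer, "weight_zero_point")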