[Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)

Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>

Author: fxmarty-amd
Date: 2025-05-08 11:53:53 +02:00
Committed by: GitHub
Parent: a463555dee
Commit: bb239a730f
2 changed files with 38 additions and 9 deletions

@@ -34,21 +34,24 @@ class QuarkW8A8Fp8(QuarkScheme):
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
-            if current_platform.is_fp8_fnuz():
+            if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
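
Why the reordering matters: requantize_with_max_scale collapses the N per-shard scales of a fused module into a single max scale by dequantizing and requantizing each logical shard. ROCm GPUs that run FP8 in the e4m3fnuz format need the checkpoint's e4m3fn weights and scales normalized first; the fix moves that normalization ahead of the requantization, which previously operated on the raw e4m3fn checkpoint values. The snippet below is a minimal sketch of what the fn-to-fnuz normalization does, assuming PyTorch's float8 dtypes; it is an illustration, not the vLLM helper itself.

import torch

def normalize_fn_to_fnuz(weight: torch.Tensor, weight_scale: torch.Tensor):
    """Sketch only: reinterpret e4m3fn weights as e4m3fnuz for ROCm."""
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.int8)
    # Bit pattern 0b10000000 is -0 in e4m3fn but NaN in e4m3fnuz: clear it.
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # The same bit pattern decodes to half the value in e4m3fnuz,
    # so the scale doubles to keep dequantized values unchanged.
    return weight_fnuz, weight_scale * 2.0

With the scales doubled by the conversion, running requantize_with_max_scale afterwards computes the shared max scale in the fnuz domain, matching the format the FP8 kernels on fnuz-based AMD GPUs (e.g. MI300) actually consume.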