[Bugfix][ROCm] Fix Static Quant Issue (#31502)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Robert Shaw
2025-12-29 16:27:55 -05:00
committed by GitHub
parent 9152a30d8f
commit 56f516254c
2 changed files with 14 additions and 14 deletions

View File

@@ -1046,35 +1046,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight = layer.w2_weight
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
w13_weight, w13_weight_scale, layer.w13_input_scale = (
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
w13_weight, w13_weight_scale, layer.w13_input_scale
w13_weight, w13_weight_scale, w13_input_scale
)
)
w2_weight, w2_weight_scale, layer.w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
w2_weight, w2_weight_scale, layer.w2_input_scale
)
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
w2_weight, w2_weight_scale, w2_input_scale
)
# Per tensor kernels require single activation scale. Use the max.
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
assert layer.w13_input_scale is not None
assert layer.w2_input_scale is not None
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
assert w13_input_scale is not None and w2_input_scale is not None
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.