[Bugfix][ROCm] Fix Static Quant Issue (#31502)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
Robert Shaw
2025-12-29 16:27:55 -05:00
committed by GitHub
parent 9152a30d8f
commit 56f516254c
2 changed files with 14 additions and 14 deletions

View File

@@ -325,8 +325,11 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
# TODO(rob): rocm_aiter_fused_experts uses self.quant_config's
# a_scales for static quantization. Update this to fit better
# with the interface once all quant integrations are complete.
assert a1q_scale is None assert a1q_scale is None
assert a2_scale is None assert a2_scale == self.quant_config.a2_scale
assert expert_tokens_meta is None assert expert_tokens_meta is None
result = rocm_aiter_fused_experts( result = rocm_aiter_fused_experts(

View File

@@ -1046,35 +1046,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight = layer.w2_weight w2_weight = layer.w2_weight
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}") w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}") w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed. # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
w13_weight, w13_weight_scale, layer.w13_input_scale = ( w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
w13_weight, w13_weight_scale, layer.w13_input_scale w13_weight, w13_weight_scale, w13_input_scale
) )
) )
w2_weight, w2_weight_scale, layer.w2_input_scale = ( w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
normalize_e4m3fn_to_e4m3fnuz( w2_weight, w2_weight_scale, w2_input_scale
w2_weight, w2_weight_scale, layer.w2_input_scale
)
) )
# Per tensor kernels require single activation scale. Use the max. # Per tensor kernels require single activation scale. Use the max.
if self.quant_config.activation_scheme == "static": if self.quant_config.activation_scheme == "static":
assert not self.block_quant assert not self.block_quant
assert layer.w13_input_scale is not None assert w13_input_scale is not None and w2_input_scale is not None
assert layer.w2_input_scale is not None if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once( logger.warning_once(
"Found input_scales that are not equal for " "Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts " "fp8 MoE layer. Using the maximum across experts "
"for each layer." "for each layer."
) )
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max()) replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max()) replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
# Per tensor kernels require single weight scale for w13 per expert, but # Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize. # on disk there is a scale for w1 and w3. Use the max to requantize.