[Bugfix][ROCm] Fix Static Quant Issue (#31502)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
@@ -325,8 +325,11 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         apply_router_weight_on_input: bool,
     ):
+        # TODO(rob): rocm_aiter_fused_experts uses self.quant_config's
+        # a_scales for static quantization. Update this to fit better
+        # with the interface once all quant integrations are complete.
         assert a1q_scale is None
-        assert a2_scale is None
+        assert a2_scale == self.quant_config.a2_scale
         assert expert_tokens_meta is None
 
         result = rocm_aiter_fused_experts(
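Note on the hunk above: with static activation quantization, the AITER path reads activation scales from `self.quant_config` rather than from the per-call arguments, so asserting `a2_scale is None` rejected valid statically quantized models. Below is a minimal sketch of the corrected invariant, using hypothetical stand-ins (`FakeQuantConfig`, `check_a2_scale`) rather than the real vLLM classes; only the `a2_scale` handling is modeled.

```python
import torch


class FakeQuantConfig:
    """Hypothetical stand-in for the quant config referenced in the diff."""

    def __init__(self, a2_scale: torch.Tensor | None):
        self.a2_scale = a2_scale


def check_a2_scale(a2_scale: torch.Tensor | None, quant_config: FakeQuantConfig) -> None:
    # Old guard: `assert a2_scale is None` -- fails for static quant,
    # where the checkpoint supplies a real per-tensor scale.
    # New guard: the caller's a2_scale must match the scale the AITER
    # kernel reads from the config, whether it is a tensor or None.
    assert a2_scale == quant_config.a2_scale


# Dynamic quantization: no activation scale on either side.
check_a2_scale(None, FakeQuantConfig(a2_scale=None))

# Static quantization: the same loaded scale flows through both paths.
scale = torch.tensor(0.02)
check_a2_scale(scale, FakeQuantConfig(a2_scale=scale))
```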
@@ -1046,35 +1046,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         w2_weight = layer.w2_weight
         w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
         w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
+        w13_input_scale = layer.w13_input_scale
+        w2_input_scale = layer.w2_input_scale
 
         # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
         if current_platform.is_fp8_fnuz():
-            w13_weight, w13_weight_scale, layer.w13_input_scale = (
+            w13_weight, w13_weight_scale, w13_input_scale = (
                 normalize_e4m3fn_to_e4m3fnuz(
-                    w13_weight, w13_weight_scale, layer.w13_input_scale
+                    w13_weight, w13_weight_scale, w13_input_scale
                 )
             )
-            w2_weight, w2_weight_scale, layer.w2_input_scale = (
-                normalize_e4m3fn_to_e4m3fnuz(
-                    w2_weight, w2_weight_scale, layer.w2_input_scale
-                )
-            )
+            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w2_weight, w2_weight_scale, w2_input_scale
+            )
 
         # Per tensor kernels require single activation scale. Use the max.
         if self.quant_config.activation_scheme == "static":
             assert not self.block_quant
-            assert layer.w13_input_scale is not None
-            assert layer.w2_input_scale is not None
-            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
-                layer.w2_input_scale
-            ):
+            assert w13_input_scale is not None and w2_input_scale is not None
+            if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
                 logger.warning_once(
                     "Found input_scales that are not equal for "
                     "fp8 MoE layer. Using the maximum across experts "
                     "for each layer."
                 )
-            replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
-            replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
+            replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
+            replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
 
         # Per tensor kernels require single weight scale for w13 per expert, but
         # on disk there is a scale for w1 and w3. Use the max to requantize.
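Two details in this hunk are worth unpacking. The FNUZ conversion doubles the scales because the same e4m3 bit pattern decodes to half the value under the fnuz exponent bias, and the per-tensor kernels accept only one activation scale per layer, so the per-expert scales are reduced with max (the largest scale covers every expert's range at a small precision cost). Below is a minimal sketch of both steps with hypothetical helper names; only the scale-doubling behavior of `normalize_e4m3fn_to_e4m3fnuz` is modeled, not its dtype reinterpretation.

```python
import torch


def double_scale_for_fnuz(scale: torch.Tensor) -> torch.Tensor:
    # For the same bit pattern, an e4m3fnuz value is half of the
    # corresponding e4m3fn value (the exponent bias differs by one),
    # so dequantization scales must double when weights are
    # reinterpreted as fnuz on MI300x/MI325x.
    return scale * 2.0


def reduce_input_scales(per_expert_scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor kernels take one activation scale per layer, but the
    # checkpoint stores one per expert. max is the safe reduction: the
    # largest scale covers every expert's activation range, trading a
    # little precision for no clipping.
    if not torch.allclose(per_expert_scale, per_expert_scale.max()):
        print("input_scales differ across experts; using the max")
    return per_expert_scale.max()


# Four experts with slightly different calibrated scales.
w13_input_scale = torch.tensor([0.018, 0.021, 0.020, 0.019])
print(reduce_input_scales(double_scale_for_fnuz(w13_input_scale)))  # tensor(0.0420)
```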