[Bugfix][ROCm] Fix Static Quant Issue (#31502)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -1046,35 +1046,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_weight = layer.w2_weight
|
||||
w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
w13_input_scale = layer.w13_input_scale
|
||||
w2_input_scale = layer.w2_input_scale
|
||||
|
||||
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
|
||||
if current_platform.is_fp8_fnuz():
|
||||
w13_weight, w13_weight_scale, layer.w13_input_scale = (
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
w13_weight, w13_weight_scale, layer.w13_input_scale
|
||||
w13_weight, w13_weight_scale, w13_input_scale
|
||||
)
|
||||
)
|
||||
w2_weight, w2_weight_scale, layer.w2_input_scale = (
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
w2_weight, w2_weight_scale, layer.w2_input_scale
|
||||
)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w2_weight, w2_weight_scale, w2_input_scale
|
||||
)
|
||||
|
||||
# Per tensor kernels require single activation scale. Use the max.
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
assert not self.block_quant
|
||||
assert layer.w13_input_scale is not None
|
||||
assert layer.w2_input_scale is not None
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
assert w13_input_scale is not None and w2_input_scale is not None
|
||||
if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max())
|
||||
replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max())
|
||||
replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
|
||||
replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
|
||||
|
||||
# Per tensor kernels require single weight scale for w13 per expert, but
|
||||
# on disk there is a scale for w1 and w3. Use the max to requantize.
|
||||
|
||||
Reference in New Issue
Block a user