[MoE Refactor][12/N] Marlin Fp8 MoE Pure Function (#31499)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
commit 9152a30d8f (parent c2ff33cc8c)
Author: Robert Shaw
Date: 2025-12-29 16:27:00 -05:00
Committed by: GitHub

4 changed files with 92 additions and 76 deletions

@@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
                 w13_weight, w2_weight
             )
+        elif self.fp8_backend == Fp8MoeBackend.MARLIN:
+            (
+                workspace,
+                w13_weight,
+                w2_weight,
+                w13_weight_scale,
+                w2_weight_scale,
+            ) = prepare_moe_fp8_layer_for_marlin(
+                layer,
+                w13_weight,
+                w2_weight,
+                w13_weight_scale,
+                w2_weight_scale,
+                input_dtype=self.marlin_input_dtype,
+            )
+            layer.workspace = workspace
         elif self.fp8_backend in [
             Fp8MoeBackend.FLASHINFER_CUTLASS,
             Fp8MoeBackend.FLASHINFER_TRTLLM,
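
The hunk above is the core of the pure-function change: the weight and scale tensors are threaded through `prepare_moe_fp8_layer_for_marlin` as arguments and return values, and the call site binds only the returned `workspace` onto the layer. A minimal sketch of that contract, with hypothetical names and a `.clone()` standing in for the real Marlin repack kernels (illustrative, not vLLM's implementation):

```python
import torch
from torch.nn import Module


def prepare_fp8_marlin_sketch(
    layer: Module,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    input_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, ...]:
    # Pure function: read the inputs, return new tensors, write nothing
    # onto `layer`. `layer` and `input_dtype` are kept only to mirror the
    # call site above; the clone stands in for the actual repack step.
    workspace = torch.zeros(1024, dtype=torch.int32, device=w13_weight.device)
    return (
        workspace,
        w13_weight.contiguous().clone(),
        w2_weight.contiguous().clone(),
        w13_weight_scale,
        w2_weight_scale,
    )
```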
@@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
# TODO(rob): we do this after replace_parameter() because
# prepare_moe_fp8_layer_for_marlin uses on the layer's params
# directly. We will refactor this in a follow up PR.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
def _setup_kernel(self, layer: Module) -> None:
"""Setup Modular Kernel for TP Case"""
# NOTE(rob): this is a WIP refactor. We are first migrating
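
The deletion above is the other half of the refactor. The old helper read and wrote `layer`'s parameters in place, which forced it to run after `replace_parameter()` (per the removed TODO) and to `del` the unused input scales afterwards. With the pure-function shape that ordering constraint disappears, and the no-mutation property becomes directly checkable. A toy check against the sketch after the first hunk (again illustrative, not vLLM's real helper):

```python
# Placeholder shapes/dtypes: fp32 stands in for fp8 weights here.
layer = Module()
w13 = torch.randn(8, 64, 32)
w2 = torch.randn(8, 32, 64)
s13, s2 = torch.ones(8, 1), torch.ones(8, 1)
w13_before = w13.clone()

outputs = prepare_fp8_marlin_sketch(layer, w13, w2, s13, s2)

assert torch.equal(w13, w13_before)     # inputs are not mutated
assert not hasattr(layer, "workspace")  # the caller, not the helper, binds it
```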
@@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     ) -> FusedMoEQuantConfig | None:
         if self.fp8_backend == Fp8MoeBackend.MARLIN:
             return fp8_w8a16_moe_quant_config(
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
+                w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
+                w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
                 block_shape=self.weight_block_size,
             )
 
         return fp8_w8a8_moe_quant_config(
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
+            w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
+            w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.weight_block_size,
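
The last hunk folds the inline `block_quant` conditionals into a single attribute-name indirection. This is safe because the deleted branches encode exactly one naming difference: block-quantized checkpoints keep inverse scales under `*_weight_scale_inv`, per-tensor ones under `*_weight_scale`. An illustration of the lookup, assuming `weight_scale_name` is derived from `block_quant` elsewhere in the class (which matches the deleted condition):

```python
# Illustrative only: resolve the scale attribute name the way the new
# getattr() calls do, based on the condition the old code spelled inline.
def scale_attr(prefix: str, block_quant: bool) -> str:
    weight_scale_name = "weight_scale_inv" if block_quant else "weight_scale"
    return f"{prefix}_{weight_scale_name}"

assert scale_attr("w13", block_quant=True) == "w13_weight_scale_inv"
assert scale_attr("w2", block_quant=False) == "w2_weight_scale"
```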