[MoE Refactor][12/N] Marlin Fp8 MoE Pure Function (#31499)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -912,6 +912,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
|
||||
w13_weight, w2_weight
|
||||
)
|
||||
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
(
|
||||
workspace,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
) = prepare_moe_fp8_layer_for_marlin(
|
||||
layer,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
w13_weight_scale,
|
||||
w2_weight_scale,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
)
|
||||
layer.workspace = workspace
|
||||
|
||||
elif self.fp8_backend in [
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
@@ -937,17 +954,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
|
||||
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
|
||||
|
||||
# TODO(rob): we do this after replace_parameter() because
|
||||
# prepare_moe_fp8_layer_for_marlin uses on the layer's params
|
||||
# directly. We will refactor this in a follow up PR.
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
def _setup_kernel(self, layer: Module) -> None:
|
||||
"""Setup Modular Kernel for TP Case"""
|
||||
# NOTE(rob): this is a WIP refactor. We are first migrating
|
||||
@@ -1194,20 +1200,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
return fp8_w8a16_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
|
||||
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
),
|
||||
w2_scale=(
|
||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||
),
|
||||
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
|
||||
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.weight_block_size,
|
||||
|
||||
Reference in New Issue
Block a user