[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2026-01-06 10:47:04 -05:00
committed by GitHub
parent d3e477c013
commit af8fd73051
7 changed files with 174 additions and 85 deletions

View File

@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
register_moe_scaling_factors,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
@@ -774,6 +775,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
"activation function, but got {layer.activation}."
)
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if self.flashinfer_moe_backend is not None and dynamic_per_token:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
def create_weights(
self,
@@ -905,6 +914,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant
@@ -949,11 +960,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else:
# TODO(rob): this function is a hack that renames the scaling
# factors in the Module. This is a hack we should clean up.
register_moe_scaling_factors(layer)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=w13_weight,
w13_input_scale=w13_input_scale,
w2_weight_scale=w2_weight,
w2_input_scale=w2_input_scale,
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
@@ -990,6 +1006,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
AiterExperts,
)
# Flashinfer TRTLLM does not use the modular kernel abstraction.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.use_inplace = True
@@ -1087,7 +1107,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Setup modular kernel for TP case.
@@ -1182,6 +1208,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN uses mixed precision W8A16 config.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
@@ -1189,11 +1220,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.weight_block_size,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if (
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
and not self.block_quant
):
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=self.weight_block_size,
)
@@ -1414,7 +1472,13 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_scale=layer.w2_weight_scale,
w13_input_scale=None,
w2_input_scale=None,
)
# Setup modular kernel for TP case.