[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Robert Shaw
2026-01-06 10:47:04 -05:00
committed by GitHub
parent d3e477c013
commit af8fd73051
7 changed files with 174 additions and 85 deletions

View File

@@ -452,11 +452,14 @@ class FusedMoEQuantConfig:
- a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
per-channel scales for w1 (for W4A8 FP8).
Optional per-channel scales for w1 (for W4A8 FP8).
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
per-channel scales for w2 (for W4A8 FP8).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
Optional per-channel scales for w2 (for W4A8 FP8).
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
- a1_gscale: Optional global quantization scales for a1 (1.0 / a1_scale).
- a2_gscale: Optional global quantization scales for a2 (1.0 / a2_scale).
- w1_bias: Optional biases for w1 (GPT OSS Triton).
- w2_bias: Optional biases for w2 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization.
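
To make the per-tensor FP8 relationships above concrete, here is a small illustrative sketch (not taken from the PR; the tensor values are made up) of how the alpha and gscale fields derive from the base weight and activation scales:

import torch

# Hypothetical per-tensor FP8 scales, stored as scalar tensors.
w1_scale = torch.tensor(0.02)   # weight scale for w1
w2_scale = torch.tensor(0.03)   # weight scale for w2
a1_scale = torch.tensor(0.5)    # static activation scale for a1
a2_scale = torch.tensor(0.25)   # static activation scale for a2

g1_alphas = w1_scale * a1_scale  # dq scale for the first GEMM (w_scale * a_scale)
g2_alphas = w2_scale * a2_scale  # dq scale for the second GEMM
a1_gscale = 1.0 / a1_scale       # inverse activation scale for a1
a2_gscale = 1.0 / a2_scale       # inverse activation scale for a2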

View File

@@ -165,10 +165,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf
quant_scales = [
self.g1_alphas,
self.a2_gscale,
self.g2_alphas,
self.a1_gscale,
self.g1_alphas, # w13_weight_scale * w13_input_scale
self.a2_gscale, # 1.0 / w2_input_scale
self.g2_alphas, # w2_weight_scale * w2_input_scale
self.a1_scale,
]
a1q_scale = None # not passing input_sf in fp8
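
As a rough sketch (the helper name is hypothetical; the field names come from the FusedMoEQuantConfig docstring in the first hunk, and the ordering is only what the hunk above shows), the FP8 per-tensor quant_scales list could be assembled like this:

def build_fp8_per_tensor_quant_scales(quant_config):
    # Ordering mirrors the hunk above: gemm1 dq scale, inverse a2 scale,
    # gemm2 dq scale, then the static a1 input scale.
    return [
        quant_config.g1_alphas,  # w13_weight_scale * w13_input_scale
        quant_config.a2_gscale,  # 1.0 / w2_input_scale
        quant_config.g2_alphas,  # w2_weight_scale * w2_input_scale
        quant_config.a1_scale,   # static input scale for the first GEMM
    ]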

View File

@@ -184,13 +184,14 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
is_nvfp4 = quant_config.quant_dtype == "nvfp4"
if not self.use_dp and is_nvfp4:
return a1, None, None, topk_ids, topk_weights
if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
@@ -222,7 +223,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
topk_weights, topk_ids, a1q = gathered
a1q_scale = None
if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
if is_nvfp4 and a1q_scale is not None:
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
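
The scale selection introduced above amounts to the following; a minimal sketch with a hypothetical helper name, assuming only the config fields shown in this diff:

def pick_a1_quant_scale(quant_config):
    # nvfp4 quantizes against the global scale; the fp8 paths use the
    # static per-tensor input scale.
    if quant_config.quant_dtype == "nvfp4":
        return quant_config.a1_gscale
    return quant_config.a1_scale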

View File

@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
register_moe_scaling_factors,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
@@ -774,6 +775,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
"activation function, but got {layer.activation}."
)
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if self.flashinfer_moe_backend is not None and dynamic_per_token:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
def create_weights(
self,
@@ -905,6 +914,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant
@@ -949,11 +960,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else:
# TODO(rob): this function is a hack that renames the scaling
# factors in the Module. This is a hack we should clean up.
register_moe_scaling_factors(layer)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=w13_weight_scale,
w13_input_scale=w13_input_scale,
w2_weight_scale=w2_weight_scale,
w2_input_scale=w2_input_scale,
)
elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
@@ -990,6 +1006,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
AiterExperts,
)
# Flashinfer TRTLLM does not use the modular kernel abstraction.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.use_inplace = True
@@ -1087,7 +1107,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Setup modular kernel for TP case.
@@ -1182,6 +1208,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
# MARLIN uses mixed precision W8A16 config.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
@@ -1189,11 +1220,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.weight_block_size,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
# FlashInfer CUTLASS per-tensor uses a single dq scale
# (alpha = w_scale * a_scale) and an inverse a2 scale.
if (
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
and not self.block_quant
):
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=self.weight_block_size,
)
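
Concretely, with hypothetical per-tensor scales w1_scale = 0.02, a1_scale = 0.5, w2_scale = 0.03, and a2_scale = 0.25, the FlashInfer CUTLASS branch above would pass g1_alphas = 0.02 * 0.5 = 0.01, g2_alphas = 0.03 * 0.25 = 0.0075, and a2_scale = 1 / 0.25 = 4.0 to fp8_w8a8_moe_quant_config, while the default branch at the bottom passes the raw scales through unchanged.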
@@ -1414,7 +1472,13 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_scale=layer.w2_weight_scale,
w13_input_scale=None,
w2_input_scale=None,
)
# Setup modular kernel for TP case.

View File

@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
register_moe_scaling_factors,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
@@ -947,9 +948,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.flashinfer_moe_backend is not None:
if self.moe.is_act_and_mul:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
# NOTE: this adds some attributes used by the trtllm kernel,
# which does not conform to the modular kernels abstraction (yet).
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_moe_scaling_factors(layer)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=layer.w13_weight_scale,
w13_input_scale=layer.w13_input_scale,
w2_weight_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
)
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
@@ -999,19 +1009,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# TRTLLM does not use modular kernels
return None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
w2_scale=layer.w2_weight_scale,
g2_alphas=layer.output2_scales_scalar.squeeze(),
a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a2_gscale=layer.w2_input_scale_inv,
per_act_token_quant=False,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a1_gscale=(1.0 / layer.w13_input_scale),
a2_gscale=(1.0 / layer.w2_input_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
else:
assert self.flashinfer_moe_backend is None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
def apply(
self,

View File

@@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights(
)
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
w13_weight_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_weight_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_weight_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
layer.output2_scales_scalar = g2_alphas
def apply_flashinfer_per_tensor_scale_fp8(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
@@ -117,19 +137,14 @@ def apply_flashinfer_per_tensor_scale_fp8(
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_scalar to be initialized"
)
assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_gate_scalar to be initialized"
)
assert layer.output1_scales_scalar is not None, (
"Expected output2_scales_scalar to be initialized"
)
from vllm.model_executor.models.llama4 import Llama4MoE
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
"FusedMoE flashinfer kernels are only supported for Llama4"
)
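
As a quick numeric check (illustrative values only; scalar per-tensor scales assumed), the attributes that register_scales_for_trtllm_fp8_per_tensor_moe places on the layer, and that the assertion above expects to be present, work out as follows:

import torch

w13_weight_scale = torch.tensor(0.02)
w13_input_scale = torch.tensor(0.5)
w2_weight_scale = torch.tensor(0.03)
w2_input_scale = torch.tensor(0.25)

g1_alphas = w13_weight_scale * w13_input_scale          # 0.01
g2_alphas = w2_weight_scale * w2_input_scale            # 0.0075
w2_input_scale_inv = 1.0 / w2_input_scale               # 4.0

output1_scales_gate_scalar = g1_alphas                   # 0.01
output1_scales_scalar = g1_alphas * w2_input_scale_inv   # 0.04
output2_scales_scalar = g2_alphas                        # 0.0075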
@@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8(
)
def get_moe_scaling_factors(
input_scale: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
activation_scale: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale
def make_fp8_moe_alpha_scales_for_fi(
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
g1_alphas = (w13_scale * w13_input_scale).squeeze()
g2_alphas = (w2_scale * w2_input_scale).squeeze()
return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors(
layer.w13_input_scale,
layer.w13_weight_scale,
layer.w2_input_scale,
layer.w2_weight_scale,
)
layer.register_parameter(
"output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False)
)
layer.register_parameter(
"output1_scales_gate_scalar",
torch.nn.Parameter(output1_gate_scales, requires_grad=False),
)
layer.register_parameter(
"output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False)
)
layer.register_parameter(
"w2_input_scale_inv",
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False),
)
return g1_alphas, g2_alphas
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
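
Finally, a minimal usage sketch for the new helper (the import path appears earlier in this diff; the calling code and values are hypothetical): make_fp8_moe_alpha_scales_for_fi simply fuses each weight scale with its matching static input scale and squeezes the result to a scalar.

import torch

from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    make_fp8_moe_alpha_scales_for_fi,
)

w13_weight_scale = torch.tensor([0.02])
w13_input_scale = torch.tensor([0.5])
w2_weight_scale = torch.tensor([0.03])
w2_input_scale = torch.tensor([0.25])

g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
    w13_scale=w13_weight_scale,
    w13_input_scale=w13_input_scale,
    w2_scale=w2_weight_scale,
    w2_input_scale=w2_input_scale,
)
assert torch.isclose(g1_alphas, torch.tensor(0.0100))  # 0.02 * 0.5
assert torch.isclose(g2_alphas, torch.tensor(0.0075))  # 0.03 * 0.25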