[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Authored by Robert Shaw on 2026-01-06 10:47:04 -05:00, committed by GitHub
parent d3e477c013
commit af8fd73051
7 changed files with 174 additions and 85 deletions


@@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights(
    )


def register_scales_for_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    w13_weight_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> None:
    """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
    g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
        w13_scale=w13_weight_scale,
        w13_input_scale=w13_input_scale,
        w2_scale=w2_weight_scale,
        w2_input_scale=w2_input_scale,
    )

    layer.w2_input_scale_inv = 1.0 / w2_input_scale
    layer.output1_scales_gate_scalar = g1_alphas
    layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
    layer.output2_scales_scalar = g2_alphas


def apply_flashinfer_per_tensor_scale_fp8(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
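For orientation, a minimal usage sketch, not part of this diff; the dummy module and scale values below are illustrative assumptions. It shows how a per-tensor FP8 quant method could call the new helper during weight post-processing and which attributes it leaves on the layer:

import torch

layer = torch.nn.Module()  # stand-in for a FusedMoE layer (illustrative)
register_scales_for_trtllm_fp8_per_tensor_moe(
    layer,
    w13_weight_scale=torch.rand(4),     # one scale per expert, made up here
    w13_input_scale=torch.tensor(0.5),
    w2_weight_scale=torch.rand(4),
    w2_input_scale=torch.tensor(0.25),
)
# The layer now carries w2_input_scale_inv, output1_scales_scalar,
# output1_scales_gate_scalar and output2_scales_scalar, which is exactly
# what apply_flashinfer_per_tensor_scale_fp8 asserts on below.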
@@ -117,19 +137,14 @@ def apply_flashinfer_per_tensor_scale_fp8(
    from flashinfer.fused_moe import RoutingMethodType

    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

    assert layer.output1_scales_scalar is not None, (
        "Expected output1_scales_scalar to be initialized"
    )
    assert layer.output1_scales_gate_scalar is not None, (
        "Expected output1_scales_gate_scalar to be initialized"
    )
    assert layer.output2_scales_scalar is not None, (
        "Expected output2_scales_scalar to be initialized"
    )
    from vllm.model_executor.models.llama4 import Llama4MoE

    assert (
        hasattr(layer, "output1_scales_scalar")
        and hasattr(layer, "output1_scales_gate_scalar")
        and hasattr(layer, "output2_scales_scalar")
    )
    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
        "FusedMoE flashinfer kernels are only supported for Llama4"
    )
@@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8(
    )


def get_moe_scaling_factors(
    input_scale: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    activation_scale: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
    output2_scales_scalar = activation_scale * gemm2_weights_scale
def make_fp8_moe_alpha_scales_for_fi(
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    g1_alphas = (w13_scale * w13_input_scale).squeeze()
    g2_alphas = (w2_scale * w2_input_scale).squeeze()
    return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
    output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors(
        layer.w13_input_scale,
        layer.w13_weight_scale,
        layer.w2_input_scale,
        layer.w2_weight_scale,
    )
    layer.register_parameter(
        "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False)
    )
    layer.register_parameter(
        "output1_scales_gate_scalar",
        torch.nn.Parameter(output1_gate_scales, requires_grad=False),
    )
    layer.register_parameter(
        "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False)
    )
    layer.register_parameter(
        "w2_input_scale_inv",
        torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False),
    )
    return g1_alphas, g2_alphas


def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
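
As a cross-check, not part of the diff and using illustrative values: the three scalars produced by the removed get_moe_scaling_factors are simple functions of the two alphas returned by the new make_fp8_moe_alpha_scales_for_fi, which is what register_scales_for_trtllm_fp8_per_tensor_moe relies on above:

import torch

w13_scale = torch.rand(4)            # gemm1 (w13) weight scale per expert
w13_input_scale = torch.tensor(0.5)  # gemm1 input scale
w2_scale = torch.rand(4)             # gemm2 (w2) weight scale per expert
w2_input_scale = torch.tensor(0.25)  # gemm2 input (activation) scale

g1_alphas = (w13_scale * w13_input_scale).squeeze()
g2_alphas = (w2_scale * w2_input_scale).squeeze()

# The removed triple, expressed in terms of the new alphas:
output1_scales_gate_scalar = g1_alphas
output1_scales_scalar = g1_alphas * (1.0 / w2_input_scale)
output2_scales_scalar = g2_alphas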