[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
@@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights(
     )


+def register_scales_for_trtllm_fp8_per_tensor_moe(
+    layer: torch.nn.Module,
+    w13_weight_scale: torch.Tensor,
+    w13_input_scale: torch.Tensor,
+    w2_weight_scale: torch.Tensor,
+    w2_input_scale: torch.Tensor,
+) -> None:
+    """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
+    g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
+        w13_scale=w13_weight_scale,
+        w13_input_scale=w13_input_scale,
+        w2_scale=w2_weight_scale,
+        w2_input_scale=w2_input_scale,
+    )
+    layer.w2_input_scale_inv = 1.0 / w2_input_scale
+    layer.output1_scales_gate_scalar = g1_alphas
+    layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
+    layer.output2_scales_scalar = g2_alphas
+
+
 def apply_flashinfer_per_tensor_scale_fp8(
     layer: torch.nn.Module,
     hidden_states: torch.Tensor,
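To see what the new registration produces end to end, here is a minimal self-contained sketch. The scale values and the (num_experts, 1) weight-scale shape are illustrative assumptions (implied by the .squeeze() in make_fp8_moe_alpha_scales_for_fi, added later in this diff), and the body mirrors the two helpers rather than importing them:

import torch

# Illustrative per-tensor scales for a 2-expert FP8 MoE layer.
w13_weight_scale = torch.rand(2, 1)
w13_input_scale = torch.rand(1)
w2_weight_scale = torch.rand(2, 1)
w2_input_scale = torch.rand(1)

# Per-expert dequant alphas, mirroring make_fp8_moe_alpha_scales_for_fi:
g1_alphas = (w13_weight_scale * w13_input_scale).squeeze()
g2_alphas = (w2_weight_scale * w2_input_scale).squeeze()

# Mirror the attribute registration: the kernel reads these directly,
# with gemm2's input quantization folded into gemm1's output scale.
layer = torch.nn.Module()
layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
layer.output2_scales_scalar = g2_alphas

print(layer.output1_scales_scalar.shape)  # torch.Size([2]): one alpha per expert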
@@ -117,19 +137,14 @@ def apply_flashinfer_per_tensor_scale_fp8(
     from flashinfer.fused_moe import RoutingMethodType

     import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401

+    assert layer.output1_scales_scalar is not None, (
+        "Expected output1_scales_scalar to be initialized"
+    )
+    assert layer.output1_scales_gate_scalar is not None, (
+        "Expected output1_scales_gate_scalar to be initialized"
+    )
+    assert layer.output2_scales_scalar is not None, (
+        "Expected output2_scales_scalar to be initialized"
+    )
+
     from vllm.model_executor.models.llama4 import Llama4MoE

-    assert (
-        hasattr(layer, "output1_scales_scalar")
-        and hasattr(layer, "output1_scales_gate_scalar")
-        and hasattr(layer, "output2_scales_scalar")
-    )
-
     assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
         "FusedMoE flashinfer kernels are only supported for Llama4"
     )
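The explicit is-not-None asserts are stricter than the removed hasattr chain: hasattr cannot tell a registered-but-empty scale from a populated one. A minimal sketch of the difference, using a bare torch.nn.Module as a stand-in for the quantized layer:

import torch

layer = torch.nn.Module()
layer.output1_scales_scalar = None  # attribute exists but was never filled in

# The removed hasattr-style check passes even though the scale is unusable:
assert hasattr(layer, "output1_scales_scalar")

# The per-attribute check fails loudly and names the missing scale:
try:
    assert layer.output1_scales_scalar is not None, (
        "Expected output1_scales_scalar to be initialized"
    )
except AssertionError as err:
    print(err)  # Expected output1_scales_scalar to be initialized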
@@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8(
     )


-def get_moe_scaling_factors(
-    input_scale: torch.Tensor,
-    gemm1_weights_scale: torch.Tensor,
-    activation_scale: torch.Tensor,
-    gemm2_weights_scale: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
-    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
-    output2_scales_scalar = activation_scale * gemm2_weights_scale
-
-    return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
-
-
-def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
-    output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors(
-        layer.w13_input_scale,
-        layer.w13_weight_scale,
-        layer.w2_input_scale,
-        layer.w2_weight_scale,
-    )
-    layer.register_parameter(
-        "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False)
-    )
-    layer.register_parameter(
-        "output1_scales_gate_scalar",
-        torch.nn.Parameter(output1_gate_scales, requires_grad=False),
-    )
-    layer.register_parameter(
-        "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False)
-    )
-    layer.register_parameter(
-        "w2_input_scale_inv",
-        torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False),
-    )
+def make_fp8_moe_alpha_scales_for_fi(
+    w13_scale: torch.Tensor,
+    w13_input_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    w2_input_scale: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    g1_alphas = (w13_scale * w13_input_scale).squeeze()
+    g2_alphas = (w2_scale * w2_input_scale).squeeze()
+
+    return g1_alphas, g2_alphas


 def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
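A quick numeric check (values arbitrary) that the new alpha-scale formulation reproduces the three scalars the removed get_moe_scaling_factors computed, under the argument mapping the removed register_moe_scaling_factors used: input_scale is w13_input_scale, gemm1_weights_scale is w13_scale, activation_scale is w2_input_scale, and gemm2_weights_scale is w2_scale:

import torch

w13_weight_scale = torch.tensor([[0.5], [0.25]])   # per-expert gemm1 weight scales
w13_input_scale = torch.tensor(2.0)
w2_weight_scale = torch.tensor([[0.125], [0.75]])  # per-expert gemm2 weight scales
w2_input_scale = torch.tensor(4.0)

# Old formulation (removed get_moe_scaling_factors):
old_output1 = w13_weight_scale * w13_input_scale * (1.0 / w2_input_scale)
old_output1_gate = w13_weight_scale * w13_input_scale
old_output2 = w2_input_scale * w2_weight_scale

# New formulation (make_fp8_moe_alpha_scales_for_fi plus registration):
g1_alphas = (w13_weight_scale * w13_input_scale).squeeze()
g2_alphas = (w2_weight_scale * w2_input_scale).squeeze()
new_output1 = g1_alphas * (1.0 / w2_input_scale)

assert torch.allclose(old_output1.squeeze(), new_output1)
assert torch.allclose(old_output1_gate.squeeze(), g1_alphas)
assert torch.allclose(old_output2.squeeze(), g2_alphas)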