[MoE Refactor][14/N] Clean Up FI Quant Config Smuggling (#31593)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -452,11 +452,14 @@ class FusedMoEQuantConfig:
|
||||
- a1_scale: Optional scale to be used for a1.
|
||||
- a2_scale: Optional scale to be used for a2.
|
||||
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
|
||||
per-channel scales for w1 (for W4A8 FP8).
|
||||
Optional per-channel scales for w1 (for W4A8 FP8).
|
||||
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
|
||||
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
|
||||
per-channel scales for w2 (for W4A8 FP8).
|
||||
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
|
||||
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
|
||||
Optional per-channel scales for w2 (for W4A8 FP8).
|
||||
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
|
||||
- a1_gscale: Optional global quantization scales for a1 (1.0 /a2_scale).
|
||||
- a2_gscale: Optional global quantization scales for a2 (1.0 /a2_scale).
|
||||
|
||||
- w1_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w2_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||
|
||||
@@ -165,10 +165,10 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
):
|
||||
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf
|
||||
quant_scales = [
|
||||
self.g1_alphas,
|
||||
self.a2_gscale,
|
||||
self.g2_alphas,
|
||||
self.a1_gscale,
|
||||
self.g1_alphas, # w13_weight_scale * w13_input_scale
|
||||
self.a2_gscale, # 1.0 / w2_input_scale
|
||||
self.g2_alphas, # w2_weight_scale * w2_input_scale
|
||||
self.a1_scale,
|
||||
]
|
||||
|
||||
a1q_scale = None # not passing input_sf in fp8
|
||||
|
||||
@@ -184,13 +184,14 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
self._apply_router_weight_on_input(
|
||||
a1, topk_weights, topk_ids, apply_router_weight_on_input
|
||||
)
|
||||
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
|
||||
is_nvfp4 = quant_config.quant_dtype == "nvfp4"
|
||||
if not self.use_dp and is_nvfp4:
|
||||
return a1, None, None, topk_ids, topk_weights
|
||||
|
||||
if not self.use_deepseek_fp8_block_scale:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale,
|
||||
quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
@@ -222,7 +223,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
topk_weights, topk_ids, a1q = gathered
|
||||
a1q_scale = None
|
||||
|
||||
if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
|
||||
if is_nvfp4 and a1q_scale is not None:
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8,
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||
get_flashinfer_moe_backend,
|
||||
register_moe_scaling_factors,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl,
|
||||
swap_w13_to_w31,
|
||||
@@ -774,6 +775,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
|
||||
"activation function, but got {layer.activation}."
|
||||
)
|
||||
dynamic_per_token = (
|
||||
not self.block_quant and self.quant_config.activation_scheme != "static"
|
||||
)
|
||||
if self.flashinfer_moe_backend is not None and dynamic_per_token:
|
||||
raise NotImplementedError(
|
||||
"FlashInfer FP8 MoE backend does not support dynamic per token "
|
||||
"activation quantization."
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -905,6 +914,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
w2_weight: torch.Tensor,
|
||||
w13_weight_scale: torch.Tensor,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor | None,
|
||||
w2_input_scale: torch.Tensor | None,
|
||||
) -> None:
|
||||
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
assert self.block_quant
|
||||
@@ -949,11 +960,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if self.block_quant:
|
||||
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
|
||||
else:
|
||||
# TODO(rob): this function is a hack that renames the scaling
|
||||
# factors in the Module. This is a hack we should clean up.
|
||||
register_moe_scaling_factors(layer)
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
w13_weight_scale=w13_weight,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_weight_scale=w2_weight,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
elif self.fp8_backend == Fp8MoeBackend.AITER:
|
||||
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
|
||||
w13_weight, w2_weight
|
||||
@@ -990,6 +1006,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
# Flashinfer TRTLLM does not use the modular kernel abstraction.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
assert self.moe_quant_config is not None
|
||||
self.use_inplace = True
|
||||
@@ -1087,7 +1107,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Shuffle weights into the runtime format.
|
||||
self._convert_weights_to_kernel_format(
|
||||
layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
|
||||
layer=layer,
|
||||
w13_weight=w13_weight,
|
||||
w2_weight=w2_weight,
|
||||
w13_weight_scale=w13_weight_scale,
|
||||
w2_weight_scale=w2_weight_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Setup modular kernel for TP case.
|
||||
@@ -1182,6 +1208,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# TRTLLM does not use Modular Kernel.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
# MARLIN uses mixed precision W8A16 config.
|
||||
if self.fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
return fp8_w8a16_moe_quant_config(
|
||||
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
|
||||
@@ -1189,11 +1220,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
a1_scale = layer.w13_input_scale
|
||||
a2_scale = layer.w2_input_scale
|
||||
|
||||
# Flashinfer CUTLASS per-tensor uses single dq scale
|
||||
# (alpha = w_scale * a_scale) and inverse a2 scale.
|
||||
if (
|
||||
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
and not self.block_quant
|
||||
):
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w1_scale,
|
||||
a1_scale,
|
||||
w2_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=(1.0 / a2_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
)
|
||||
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
|
||||
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
@@ -1414,7 +1472,13 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
|
||||
# Shuffle weights into the runtime format.
|
||||
self._convert_weights_to_kernel_format(
|
||||
layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
|
||||
layer=layer,
|
||||
w13_weight=w13_weight,
|
||||
w2_weight=w2_weight,
|
||||
w13_weight_scale=layer.w13_weight_scale,
|
||||
w2_weight_scale=layer.w2_weight_scale,
|
||||
w13_input_scale=None,
|
||||
w2_input_scale=None,
|
||||
)
|
||||
|
||||
# Setup modular kernel for TP case.
|
||||
|
||||
@@ -50,7 +50,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
get_flashinfer_moe_backend,
|
||||
is_flashinfer_supporting_global_sf,
|
||||
register_moe_scaling_factors,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl,
|
||||
swap_w13_to_w31,
|
||||
@@ -947,9 +948,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
if self.moe.is_act_and_mul:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
|
||||
# NOTE: this adds some attributes used by the trtllm kernel,
|
||||
# which does not conform to the modular kernels abstraction (yet).
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
register_moe_scaling_factors(layer)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
w13_weight_scale=layer.w13_weight_scale,
|
||||
w13_input_scale=layer.w13_input_scale,
|
||||
w2_weight_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
|
||||
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
|
||||
@@ -999,19 +1009,34 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
# TRTLLM does not use modular kernels
|
||||
return None
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g2_alphas=layer.output2_scales_scalar.squeeze(),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
layer.w13_weight_scale,
|
||||
layer.w13_input_scale,
|
||||
layer.w2_weight_scale,
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
a1_gscale=(1.0 / layer.w13_input_scale),
|
||||
a2_gscale=(1.0 / layer.w2_input_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
)
|
||||
else:
|
||||
assert self.flashinfer_moe_backend is None
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights(
|
||||
)
|
||||
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
w13_weight_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> None:
|
||||
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale=w13_weight_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_weight_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
layer.w2_input_scale_inv = 1.0 / w2_input_scale
|
||||
layer.output1_scales_gate_scalar = g1_alphas
|
||||
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
|
||||
layer.output2_scales_scalar = g2_alphas
|
||||
|
||||
|
||||
def apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -117,19 +137,14 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
from flashinfer.fused_moe import RoutingMethodType
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_scalar to be initialized"
|
||||
)
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_gate_scalar to be initialized"
|
||||
)
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output2_scales_scalar to be initialized"
|
||||
)
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
assert (
|
||||
hasattr(layer, "output1_scales_scalar")
|
||||
and hasattr(layer, "output1_scales_gate_scalar")
|
||||
and hasattr(layer, "output2_scales_scalar")
|
||||
)
|
||||
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
)
|
||||
@@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
)
|
||||
|
||||
|
||||
def get_moe_scaling_factors(
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights_scale: torch.Tensor,
|
||||
activation_scale: torch.Tensor,
|
||||
gemm2_weights_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
|
||||
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
|
||||
output2_scales_scalar = activation_scale * gemm2_weights_scale
|
||||
def make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g1_alphas = (w13_scale * w13_input_scale).squeeze()
|
||||
g2_alphas = (w2_scale * w2_input_scale).squeeze()
|
||||
|
||||
return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar
|
||||
|
||||
|
||||
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors(
|
||||
layer.w13_input_scale,
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_input_scale,
|
||||
layer.w2_weight_scale,
|
||||
)
|
||||
layer.register_parameter(
|
||||
"output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False)
|
||||
)
|
||||
layer.register_parameter(
|
||||
"output1_scales_gate_scalar",
|
||||
torch.nn.Parameter(output1_gate_scales, requires_grad=False),
|
||||
)
|
||||
layer.register_parameter(
|
||||
"output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False)
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_scale_inv",
|
||||
torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False),
|
||||
)
|
||||
return g1_alphas, g2_alphas
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
|
||||
Reference in New Issue
Block a user