From cdcffafef870cb8fcc80640b2f4ce1b39464dee5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Elvir=20Crn=C4=8Devi=C4=87?=
Date: Mon, 16 Mar 2026 23:03:54 +0100
Subject: [PATCH] Fix eplb nvfp4 experts hook (#37217)

Signed-off-by: Elvir Crncevic
Signed-off-by: Elvir Crncevic
Co-authored-by: Tyler Michael Smith
Co-authored-by: Claude Opus 4.6
(cherry picked from commit fd4d96302a2999a8d773b1b331951d232e3f5e05)
---
 .../layers/fused_moe/cutlass_moe.py           |  7 ++++++
 .../fused_moe/experts/trtllm_nvfp4_moe.py     | 23 +++++++++++++++----
 .../fused_moe/flashinfer_cutedsl_moe.py       |  4 ++++
 .../fused_moe/flashinfer_cutlass_moe.py       |  5 ++++
 vllm/model_executor/layers/fused_moe/layer.py | 18 +++++++++------
 .../layers/fused_moe/modular_kernel.py        |  3 +++
 .../layers/fused_moe/oracle/nvfp4.py          | 10 ++++----
 .../compressed_tensors_moe.py                 |  1 +
 .../layers/quantization/modelopt.py           |  1 +
 .../quantization/utils/flashinfer_fp4_moe.py  | 10 --------
 10 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 51a97e0a2..534cab1b8 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -659,6 +659,13 @@ def run_cutlass_moe_fp4(
 class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
     """CUTLASS FP4 fused MoE expert implementation."""
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Fuse activation scales into w_scale_2 in-place so that
+        # g1/g2_alphas (which reference the same tensor) stay in sync
+        # when EPLB rearranges the parameter.
+        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
+        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
+
     @property
     def expects_unquantized_inputs(self) -> bool:
         return True
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
index 174c581b3..87b1eb9fd 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@@ -56,10 +56,25 @@ class TrtLlmNvFp4ExpertsBase:
             # g1_scale_c = a13_scale * w13_scale_2 / a2_scale
             self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
         else:
-            self.g1_scale_c = (
-                torch.ones_like(self.quant_config.a1_gscale)
-                * self.quant_config.a2_gscale
-            )
+            self.g1_scale_c = self.quant_config.a2_gscale.clone()
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
+        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
+        # Recompute g1_scale_c since g1_alphas was just fused in-place.
+        # Register as a layer parameter so EPLB rearranges it alongside
+        # other expert weights.
+        assert self.quant_config.g1_alphas is not None
+        assert self.quant_config.a2_gscale is not None
+        if self.moe_config.is_act_and_mul:
+            g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
+        else:
+            g1_scale_c = self.quant_config.a2_gscale.clone()
+        layer.register_parameter(
+            "g1_scale_c",
+            torch.nn.Parameter(g1_scale_c, requires_grad=False),
+        )
+        self.g1_scale_c = layer.g1_scale_c
 
     @staticmethod
     def _supports_current_device() -> bool:
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
index fb8a18ef3..5805a4dd5 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -49,6 +49,10 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
         )
         self.out_dtype = moe_config.in_dtype
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
+        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
+
     @staticmethod
     def activation_format() -> mk.FusedMoEActivationFormat:
         return mk.FusedMoEActivationFormat.BatchedExperts
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index e58d52eee..91f7a83f6 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -61,6 +61,11 @@ def is_valid_flashinfer_cutlass_fused_moe(
 
 
 class FlashInferExperts(mk.FusedMoEExpertsModular):
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.quant_config.use_nvfp4_w4a4:
+            layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
+            layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
+
     def __init__(
         self,
         moe_config: mk.FusedMoEConfig,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 7135cbbd2..75283b9bb 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1421,19 +1421,23 @@ class FusedMoE(CustomOp):
         weights = list(self.named_parameters())
         weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]
 
+        # `w13_input_scale` and `w2_input_scale` are global per-tensor
+        # activation scales shared across all experts (e.g. NVFP4).
+        # They are broadcast views (stride 0) from .expand() and are
+        # not actual expert weights, so exclude them from EPLB.
+        NON_EXPERT_WEIGHTS = {
+            "e_score_correction_bias",
+            "w13_input_scale",
+            "w2_input_scale",
+        }
+
         assert all(
             weight.is_contiguous()
             for name, weight in weights
             if not (name.startswith("_shared_experts.") or name.startswith("_gate."))
+            and name not in NON_EXPERT_WEIGHTS
         )
 
-        # Filter out the non-expert weights.
-        # `e_score_correction_bias` is a bias for each logical expert,
-        # with shape (num_logical_experts,), not an expert weight.
-        NON_EXPERT_WEIGHTS = {
-            "e_score_correction_bias",
-        }
-
         return [
             weight.view(self.local_num_experts, -1)
             for name, weight in weights
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 7100c87c9..a6b498834 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -489,6 +489,9 @@ class FusedMoEExperts(ABC):
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:  # noqa: B027
+        pass
+
     @staticmethod
     def is_monolithic() -> bool:
         raise NotImplementedError("Implemented by subclasses.")
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index b06cf49cf..8a224cb39 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -374,11 +374,13 @@ def make_nvfp4_moe_quant_config(
         w2_scale=w2_scale,
     )
 
-    g1_alphas = a13_scale * w13_scale_2
-    g2_alphas = a2_scale * w2_scale_2
+    # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
+    # The expert's process_weights_after_loading will fuse activation
+    # scales in-place. Since the quant config references the same tensor
+    # as the registered parameter, EPLB rearrangement stays in sync.
     return nvfp4_moe_quant_config(
-        g1_alphas=g1_alphas,
-        g2_alphas=g2_alphas,
+        g1_alphas=w13_scale_2,
+        g2_alphas=w2_scale_2,
         a1_gscale=(1.0 / a13_scale),
         a2_gscale=(1.0 / a2_scale),
         w1_scale=w13_scale,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index f35a4c0b9..29115fbbc 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -570,6 +570,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
             shared_experts=layer.shared_experts,
             routing_tables=layer._maybe_init_expert_routing_tables(),
         )
+        self.moe_kernel.fused_experts.process_weights_after_loading(layer)
 
     def maybe_make_prepare_finalize(
         self,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 977612313..640580da6 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1394,6 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             shared_experts=layer.shared_experts,
             routing_tables=layer._maybe_init_expert_routing_tables(),
         )
+        self.moe_kernel.fused_experts.process_weights_after_loading(layer)
 
     def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
         return make_nvfp4_moe_quant_config(
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 42677a592..66300ceae 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -267,16 +267,6 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
             num_experts=w13.size(0),
             is_gated_activation=is_gated,
         )
-
-        # We do not need to make this a parameter, because
-        # it is not used during the weight (re)-loading process.
-        if is_gated:
-            layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
-        else:
-            layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
-        layer.a1_gscale = 1.0 / a13_scale
-        layer.g1_alphas = a13_scale * w13_scale_2
-        layer.g2_alphas = a2_scale * w2_scale_2
     else:
         # Swizzle the block scales for other FI NVFP4 MoE kernels.
         w13_scale = swizzle_blockscale(w13_scale)
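
Why the fix works: make_nvfp4_moe_quant_config now hands the registered
w13_weight_scale_2 / w2_weight_scale_2 parameters to the quant config as
g1_alphas / g2_alphas instead of freshly computed products. The old
products (a13_scale * w13_scale_2) were brand-new tensors, so once EPLB
permuted the registered parameter the cached alphas went stale. With the
alias, both the in-place mul_ in process_weights_after_loading and an
EPLB-style in-place rearrangement are visible through the quant config.
A minimal PyTorch sketch of that aliasing (not part of the patch; shapes
and values are toys):

    import torch

    num_experts = 4
    w13_weight_scale_2 = torch.nn.Parameter(
        torch.arange(1.0, 5.0), requires_grad=False
    )

    # The quant config keeps a reference to the same tensor object,
    # mirroring g1_alphas=w13_scale_2 in make_nvfp4_moe_quant_config.
    g1_alphas = w13_weight_scale_2.data

    # process_weights_after_loading fuses the activation scale in-place...
    w13_input_scale = torch.tensor(0.5)
    w13_weight_scale_2.data.mul_(w13_input_scale)
    assert torch.equal(g1_alphas, torch.tensor([0.5, 1.0, 1.5, 2.0]))

    # ...and an EPLB-style in-place permutation of the parameter is seen
    # through the alias as well, so the two can never diverge.
    perm = torch.tensor([2, 0, 3, 1])
    w13_weight_scale_2.data.copy_(w13_weight_scale_2.data[perm])
    assert torch.equal(g1_alphas, torch.tensor([1.5, 0.5, 2.0, 1.0]))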
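TrtLlm's g1_scale_c is a derived per-expert quantity, so it has to be
recomputed after g1_alphas is fused in-place, and it must appear in
named_parameters() for EPLB to rearrange it; a plain attribute (as in the
removed flashinfer_fp4_moe.py code) is invisible to that walk. A sketch of
the distinction, with made-up values and a bare nn.Module standing in for
the layer:

    import torch

    layer = torch.nn.Module()
    g1_alphas = torch.tensor([0.5, 1.0, 1.5, 2.0])  # per-expert, already fused
    a2_gscale = torch.tensor(4.0)                   # global reciprocal scale

    # A plain attribute: EPLB's named_parameters() walk never sees it.
    layer.g1_scale_c_attr = g1_alphas * a2_gscale
    assert "g1_scale_c_attr" not in dict(layer.named_parameters())

    # A registered (non-trainable) parameter travels with the experts.
    layer.register_parameter(
        "g1_scale_c",
        torch.nn.Parameter(g1_alphas * a2_gscale, requires_grad=False),
    )
    assert "g1_scale_c" in dict(layer.named_parameters())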
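The NON_EXPERT_WEIGHTS additions in layer.py follow from how NVFP4 exposes
activation scales: w13_input_scale / w2_input_scale are single per-tensor
values broadcast across experts, so per-expert rearrangement would be
meaningless, and the stride-0 view cannot be reshaped like a real expert
weight. A sketch of what such a broadcast view looks like (toy scale value
and expert count):

    import torch

    num_experts = 8
    input_scale = torch.tensor(0.0125).expand(num_experts)

    print(input_scale.stride())         # (0,): all "experts" alias one element
    print(input_scale.is_contiguous())  # False

    # view() needs contiguous memory, so treating this as an expert weight
    # in get_expert_weights-style code would raise a RuntimeError:
    try:
        input_scale.view(num_experts, -1)
    except RuntimeError as exc:
        print("not viewable as per-expert rows:", exc)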
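Finally, the new base-class hook keeps the call sites simple: both quant
methods invoke fused_experts.process_weights_after_loading(layer)
unconditionally right after building the kernel, and experts that need no
fix-up inherit the no-op from FusedMoEExperts (the # noqa: B027 marks the
empty non-abstract method as intentional). A stripped-down mock of the
dispatch, using simplified stand-ins for the vLLM classes:

    import torch
    from abc import ABC

    class FusedMoEExperts(ABC):
        def process_weights_after_loading(self, layer: torch.nn.Module) -> None:  # noqa: B027
            pass  # deliberate no-op default: most experts need no fix-up

    class NvFp4Experts(FusedMoEExperts):
        def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
            # Fuse the activation scale into the registered parameter in-place.
            layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)

    class Layer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w13_weight_scale_2 = torch.nn.Parameter(
                torch.ones(4), requires_grad=False
            )
            self.w13_input_scale = torch.tensor(2.0)

    layer = Layer()
    experts: FusedMoEExperts = NvFp4Experts()
    # The quant method can call the hook unconditionally; non-NVFP4
    # experts simply fall through to the inherited no-op.
    experts.process_weights_after_loading(layer)
    assert torch.equal(layer.w13_weight_scale_2.data, torch.full((4,), 2.0))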