From a7eae10ef4e6d1efdc045ab9b96f9672c44e465d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 23:57:08 +0000 Subject: [PATCH] fix: use checkpoint input_scale for activation quantization Critical fix: the checkpoint's input_scale was used during weight calibration but we were computing dynamic scale from data (amax/2688). This was 13x off from the checkpoint value. Changes: - stage_activation() accepts optional input_global_scale parameter - nvfp4_mega_moe_full() accepts l1_input_scale and l2_input_scale - vLLM patch preserves w13/w2_input_scale in finalize_weights - L1 activation uses checkpoint w13_input_scale for quantization - L2 activation uses checkpoint w2_input_scale for quantization - alpha = input_scale * weight_scale_2 (correct calibration contract) --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 35 +++++++++++++++++----- vllm/patches/deepseek_v4.py | 14 +++++++-- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 08eb88b4..5410bad1 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -248,13 +248,20 @@ def _quantize_to_e2m1(x_f32): return packed.to(torch.int8), sf -def stage_activation(x_bf16): +def stage_activation(x_bf16, input_global_scale=None): """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. Two-level quantization matching the NVFP4 weight format: - 1. Per-tensor global scale: amax / (6.0 * 448.0) + 1. Per-tensor global scale: amax / (6.0 * 448.0) [dynamic] OR checkpoint input_scale [static] 2. Per-block (16 values) absmax scaling on the normalized values + Args: + x_bf16: BF16 activation tensor + input_global_scale: If provided, use this checkpoint-derived scale instead of + computing dynamically. The checkpoint's input_scale was used during weight + quantization — using the same scale at runtime ensures the quantized weights + are rescaled correctly. If None, compute from data (amax / (6.0 * 448.0)). + Returns (x_fp4, x_sf, input_global_scale) where: x_fp4: packed E2M1 nibbles x_sf: UE4M3 block scales (NOT folded with global scale) @@ -262,8 +269,9 @@ def stage_activation(x_bf16): """ x_f32 = x_bf16.float() - x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8) - input_global_scale = x_amax / (6.0 * 448.0) + if input_global_scale is None: + x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8) + input_global_scale = x_amax / (6.0 * 448.0) x_normalized = x_f32 / input_global_scale @@ -279,6 +287,8 @@ def nvfp4_mega_moe_full( symm_buffer, # SymmBuffer from get_symm_buffer activation_clamp=None, # optional clamp value (unused in NVFP4) fast_math=False, # fast math flag (unused in NVFP4) + l1_input_scale=None, # (num_experts,) float32 — checkpoint input_scale for L1 (w13) + l2_input_scale=None, # (num_experts,) float32 — checkpoint input_scale for L2 (w2) ): """Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe. @@ -322,6 +332,14 @@ def nvfp4_mega_moe_full( x_sf = symm_buffer.x_sf[:num_tokens] l1_global_scale = symm_buffer.input_global_scale + # Use checkpoint input_scales for alpha computation if available + # The checkpoint input_scale was used during weight calibration. + # alpha = input_scale * weight_scale_2 (NOT dynamic_scale * weight_scale_2) + if l1_input_scale is not None: + l1_igs = float(l1_input_scale[0]) # same for all experts + else: + l1_igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale + # Diagnostic: check FP4 quantization quality by dequantizing and comparing if not getattr(nvfp4_mega_moe_full, '_quant_diag', False): nvfp4_mega_moe_full._quant_diag = True @@ -380,7 +398,8 @@ def nvfp4_mega_moe_full( return # Ensure alpha is a plain Python float for the base activation global scale - l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale + # Use checkpoint input_scale if available (from weight calibration) + l1_alpha = l1_igs # Shape consistency asserts assert slot_expert_local.ndim == 1 @@ -484,13 +503,15 @@ def nvfp4_mega_moe_full( activated = activated.clamp(max=activation_clamp) # Step 4: Quantize activated slots → FP4 - l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated) + # Use checkpoint input_scale for L2 (w2/down_proj) if available + l2_igs = float(l2_input_scale[0]) if l2_input_scale is not None else None + l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated, input_global_scale=l2_igs) # Pre-L2 shape asserts assert activated.shape[0] == num_slots assert l1_fp4.shape[0] == num_slots assert l1_sf_out.shape[0] == num_slots - l2_alpha = float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale + l2_alpha = l2_igs if l2_igs is not None else (float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale) if MEGA_MOE_DEBUG: _l1sf_f32 = l1_sf_out.to(torch.float32) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 74a39f86..816c57f4 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -457,7 +457,9 @@ class DeepseekV4MegaMoEExperts(nn.Module): ) ) - # Drop the original loader-side parameters + # Drop the original loader-side parameters (preserve input_scales) + self._w13_input_scale = self.w13_input_scale.data.clone() + self._w2_input_scale = self.w2_input_scale.data.clone() self.w13_weight = None self.w13_weight_scale = None self.w13_weight_scale_2 = None @@ -550,8 +552,12 @@ class DeepseekV4MegaMoEExperts(nn.Module): num_tokens = hidden_states.shape[0] # Quantize activation using the kernel's PyTorch stage_activation - # (same code path the kernel uses for L1→L2 requantization). - x_fp4, x_sf, input_global_scale = stage_activation(hidden_states) + # Use the checkpoint's input_scale for L1 (w13) activation quantization. + # The checkpoint's input_scale was used during weight calibration — using + # the same scale at runtime ensures the quantized weights are rescaled correctly. + # Dynamic stage_activation computes amax/(6*448) which can be 10x+ off. + w13_input_scale = float(self._w13_input_scale[0]) # same for all experts + x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale) symm_buffer.x[:num_tokens].copy_(x_fp4) symm_buffer.x_sf[:num_tokens].copy_(x_sf) symm_buffer.input_global_scale = input_global_scale @@ -572,6 +578,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): symm_buffer, activation_clamp=activation_clamp, fast_math=fast_math, + l1_input_scale=self._w13_input_scale, + l2_input_scale=self._w2_input_scale, ) if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1': torch.cuda.synchronize()