diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 5410bad1..39d02a64 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -252,15 +252,16 @@ def stage_activation(x_bf16, input_global_scale=None): """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. Two-level quantization matching the NVFP4 weight format: - 1. Per-tensor global scale: amax / (6.0 * 448.0) [dynamic] OR checkpoint input_scale [static] + 1. Per-tensor global scale: amax / (6.0 * 448.0) [default] or provided 2. Per-block (16 values) absmax scaling on the normalized values Args: x_bf16: BF16 activation tensor - input_global_scale: If provided, use this checkpoint-derived scale instead of - computing dynamically. The checkpoint's input_scale was used during weight - quantization — using the same scale at runtime ensures the quantized weights - are rescaled correctly. If None, compute from data (amax / (6.0 * 448.0)). + input_global_scale: If provided, use this as the activation global scale + instead of computing dynamically. WARNING: this is the amax/(6*448) + normalization scale, NOT the checkpoint's input_scale (which is a + different quantity used for alpha computation). Pass None to compute + dynamically from data. Returns (x_fp4, x_sf, input_global_scale) where: x_fp4: packed E2M1 nibbles @@ -332,14 +333,6 @@ def nvfp4_mega_moe_full( x_sf = symm_buffer.x_sf[:num_tokens] l1_global_scale = symm_buffer.input_global_scale - # Use checkpoint input_scales for alpha computation if available - # The checkpoint input_scale was used during weight calibration. - # alpha = input_scale * weight_scale_2 (NOT dynamic_scale * weight_scale_2) - if l1_input_scale is not None: - l1_igs = float(l1_input_scale[0]) # same for all experts - else: - l1_igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale - # Diagnostic: check FP4 quantization quality by dequantizing and comparing if not getattr(nvfp4_mega_moe_full, '_quant_diag', False): nvfp4_mega_moe_full._quant_diag = True @@ -398,8 +391,7 @@ def nvfp4_mega_moe_full( return # Ensure alpha is a plain Python float for the base activation global scale - # Use checkpoint input_scale if available (from weight calibration) - l1_alpha = l1_igs + l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale # Shape consistency asserts assert slot_expert_local.ndim == 1 @@ -503,15 +495,13 @@ def nvfp4_mega_moe_full( activated = activated.clamp(max=activation_clamp) # Step 4: Quantize activated slots → FP4 - # Use checkpoint input_scale for L2 (w2/down_proj) if available - l2_igs = float(l2_input_scale[0]) if l2_input_scale is not None else None - l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated, input_global_scale=l2_igs) + l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated) # Pre-L2 shape asserts assert activated.shape[0] == num_slots assert l1_fp4.shape[0] == num_slots assert l1_sf_out.shape[0] == num_slots - l2_alpha = l2_igs if l2_igs is not None else (float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale) + l2_alpha = float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale if MEGA_MOE_DEBUG: _l1sf_f32 = l1_sf_out.to(torch.float32) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 816c57f4..580987e0 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -552,12 +552,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): num_tokens = hidden_states.shape[0] # Quantize activation using the kernel's PyTorch stage_activation - # Use the checkpoint's input_scale for L1 (w13) activation quantization. - # The checkpoint's input_scale was used during weight calibration — using - # the same scale at runtime ensures the quantized weights are rescaled correctly. - # Dynamic stage_activation computes amax/(6*448) which can be 10x+ off. - w13_input_scale = float(self._w13_input_scale[0]) # same for all experts - x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale) + # Dynamic quantization: input_global_scale = amax / (6 * 448) + x_fp4, x_sf, input_global_scale = stage_activation(hidden_states) symm_buffer.x[:num_tokens].copy_(x_fp4) symm_buffer.x_sf[:num_tokens].copy_(x_sf) symm_buffer.input_global_scale = input_global_scale