fix: use checkpoint input_scale for activation quantization

Critical fix: the checkpoint's input_scale was used during weight
calibration but we were computing dynamic scale from data (amax/2688).
This was 13x off from the checkpoint value.

Changes:
- stage_activation() accepts optional input_global_scale parameter
- nvfp4_mega_moe_full() accepts l1_input_scale and l2_input_scale
- vLLM patch preserves w13/w2_input_scale in finalize_weights
- L1 activation uses checkpoint w13_input_scale for quantization
- L2 activation uses checkpoint w2_input_scale for quantization
- alpha = input_scale * weight_scale_2 (correct calibration contract)
This commit is contained in:
2026-05-15 23:57:08 +00:00
parent af50e98fe9
commit a7eae10ef4
2 changed files with 39 additions and 10 deletions

View File

@@ -248,13 +248,20 @@ def _quantize_to_e2m1(x_f32):
return packed.to(torch.int8), sf
def stage_activation(x_bf16):
def stage_activation(x_bf16, input_global_scale=None):
"""Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales.
Two-level quantization matching the NVFP4 weight format:
1. Per-tensor global scale: amax / (6.0 * 448.0)
1. Per-tensor global scale: amax / (6.0 * 448.0) [dynamic] OR checkpoint input_scale [static]
2. Per-block (16 values) absmax scaling on the normalized values
Args:
x_bf16: BF16 activation tensor
input_global_scale: If provided, use this checkpoint-derived scale instead of
computing dynamically. The checkpoint's input_scale was used during weight
quantization — using the same scale at runtime ensures the quantized weights
are rescaled correctly. If None, compute from data (amax / (6.0 * 448.0)).
Returns (x_fp4, x_sf, input_global_scale) where:
x_fp4: packed E2M1 nibbles
x_sf: UE4M3 block scales (NOT folded with global scale)
@@ -262,8 +269,9 @@ def stage_activation(x_bf16):
"""
x_f32 = x_bf16.float()
x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
input_global_scale = x_amax / (6.0 * 448.0)
if input_global_scale is None:
x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
input_global_scale = x_amax / (6.0 * 448.0)
x_normalized = x_f32 / input_global_scale
@@ -279,6 +287,8 @@ def nvfp4_mega_moe_full(
symm_buffer, # SymmBuffer from get_symm_buffer
activation_clamp=None, # optional clamp value (unused in NVFP4)
fast_math=False, # fast math flag (unused in NVFP4)
l1_input_scale=None, # (num_experts,) float32 — checkpoint input_scale for L1 (w13)
l2_input_scale=None, # (num_experts,) float32 — checkpoint input_scale for L2 (w2)
):
"""Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe.
@@ -322,6 +332,14 @@ def nvfp4_mega_moe_full(
x_sf = symm_buffer.x_sf[:num_tokens]
l1_global_scale = symm_buffer.input_global_scale
# Use checkpoint input_scales for alpha computation if available
# The checkpoint input_scale was used during weight calibration.
# alpha = input_scale * weight_scale_2 (NOT dynamic_scale * weight_scale_2)
if l1_input_scale is not None:
l1_igs = float(l1_input_scale[0]) # same for all experts
else:
l1_igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
# Diagnostic: check FP4 quantization quality by dequantizing and comparing
if not getattr(nvfp4_mega_moe_full, '_quant_diag', False):
nvfp4_mega_moe_full._quant_diag = True
@@ -380,7 +398,8 @@ def nvfp4_mega_moe_full(
return
# Ensure alpha is a plain Python float for the base activation global scale
l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
# Use checkpoint input_scale if available (from weight calibration)
l1_alpha = l1_igs
# Shape consistency asserts
assert slot_expert_local.ndim == 1
@@ -484,13 +503,15 @@ def nvfp4_mega_moe_full(
activated = activated.clamp(max=activation_clamp)
# Step 4: Quantize activated slots → FP4
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)
# Use checkpoint input_scale for L2 (w2/down_proj) if available
l2_igs = float(l2_input_scale[0]) if l2_input_scale is not None else None
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated, input_global_scale=l2_igs)
# Pre-L2 shape asserts
assert activated.shape[0] == num_slots
assert l1_fp4.shape[0] == num_slots
assert l1_sf_out.shape[0] == num_slots
l2_alpha = float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale
l2_alpha = l2_igs if l2_igs is not None else (float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale)
if MEGA_MOE_DEBUG:
_l1sf_f32 = l1_sf_out.to(torch.float32)

View File

@@ -457,7 +457,9 @@ class DeepseekV4MegaMoEExperts(nn.Module):
)
)
# Drop the original loader-side parameters
# Drop the original loader-side parameters (preserve input_scales)
self._w13_input_scale = self.w13_input_scale.data.clone()
self._w2_input_scale = self.w2_input_scale.data.clone()
self.w13_weight = None
self.w13_weight_scale = None
self.w13_weight_scale_2 = None
@@ -550,8 +552,12 @@ class DeepseekV4MegaMoEExperts(nn.Module):
num_tokens = hidden_states.shape[0]
# Quantize activation using the kernel's PyTorch stage_activation
# (same code path the kernel uses for L1→L2 requantization).
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
# Use the checkpoint's input_scale for L1 (w13) activation quantization.
# The checkpoint's input_scale was used during weight calibration — using
# the same scale at runtime ensures the quantized weights are rescaled correctly.
# Dynamic stage_activation computes amax/(6*448) which can be 10x+ off.
w13_input_scale = float(self._w13_input_scale[0]) # same for all experts
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale)
symm_buffer.x[:num_tokens].copy_(x_fp4)
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
symm_buffer.input_global_scale = input_global_scale
@@ -572,6 +578,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
symm_buffer,
activation_clamp=activation_clamp,
fast_math=fast_math,
l1_input_scale=self._w13_input_scale,
l2_input_scale=self._w2_input_scale,
)
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()