revert: don't use checkpoint input_scale for activation normalization
Using the checkpoint's input_scale as the activation normalization scale saturates the FP4 values (all block scales = 448). The input_scale is a calibration constant, NOT the amax/(6*448) normalization scale, so this commit reverts activation quantization to the dynamic amax/(6*448) scale. The correct use of the checkpoint's input_scale is still under investigation. Preserved: _w13_input_scale and _w2_input_scale are kept in finalize_weights for future use, once the correct alpha contract is understood.
This commit is contained in:
@@ -252,15 +252,16 @@ def stage_activation(x_bf16, input_global_scale=None):
|
||||
"""Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales.
|
||||
|
||||
Two-level quantization matching the NVFP4 weight format:
|
||||
1. Per-tensor global scale: amax / (6.0 * 448.0) [dynamic] OR checkpoint input_scale [static]
|
||||
1. Per-tensor global scale: amax / (6.0 * 448.0) [default] or provided
|
||||
2. Per-block (16 values) absmax scaling on the normalized values
|
||||
|
||||
Args:
|
||||
x_bf16: BF16 activation tensor
|
||||
input_global_scale: If provided, use this checkpoint-derived scale instead of
|
||||
computing dynamically. The checkpoint's input_scale was used during weight
|
||||
quantization — using the same scale at runtime ensures the quantized weights
|
||||
are rescaled correctly. If None, compute from data (amax / (6.0 * 448.0)).
|
||||
input_global_scale: If provided, use this as the activation global scale
|
||||
instead of computing dynamically. WARNING: this is the amax/(6*448)
|
||||
normalization scale, NOT the checkpoint's input_scale (which is a
|
||||
different quantity used for alpha computation). Pass None to compute
|
||||
dynamically from data.
|
||||
|
||||
Returns (x_fp4, x_sf, input_global_scale) where:
|
||||
x_fp4: packed E2M1 nibbles
|
||||
@@ -332,14 +333,6 @@ def nvfp4_mega_moe_full(
|
||||
x_sf = symm_buffer.x_sf[:num_tokens]
|
||||
l1_global_scale = symm_buffer.input_global_scale
|
||||
|
||||
# Use checkpoint input_scales for alpha computation if available
|
||||
# The checkpoint input_scale was used during weight calibration.
|
||||
# alpha = input_scale * weight_scale_2 (NOT dynamic_scale * weight_scale_2)
|
||||
if l1_input_scale is not None:
|
||||
l1_igs = float(l1_input_scale[0]) # same for all experts
|
||||
else:
|
||||
l1_igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
|
||||
|
||||
# Diagnostic: check FP4 quantization quality by dequantizing and comparing
|
||||
if not getattr(nvfp4_mega_moe_full, '_quant_diag', False):
|
||||
nvfp4_mega_moe_full._quant_diag = True
|
||||
@@ -398,8 +391,7 @@ def nvfp4_mega_moe_full(
|
||||
return
|
||||
|
||||
# Ensure alpha is a plain Python float for the base activation global scale
|
||||
# Use checkpoint input_scale if available (from weight calibration)
|
||||
l1_alpha = l1_igs
|
||||
l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
|
||||
|
||||
# Shape consistency asserts
|
||||
assert slot_expert_local.ndim == 1
|
||||
@@ -503,15 +495,13 @@ def nvfp4_mega_moe_full(
|
||||
activated = activated.clamp(max=activation_clamp)
|
||||
|
||||
# Step 4: Quantize activated slots → FP4
|
||||
# Use checkpoint input_scale for L2 (w2/down_proj) if available
|
||||
l2_igs = float(l2_input_scale[0]) if l2_input_scale is not None else None
|
||||
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated, input_global_scale=l2_igs)
|
||||
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)
|
||||
|
||||
# Pre-L2 shape asserts
|
||||
assert activated.shape[0] == num_slots
|
||||
assert l1_fp4.shape[0] == num_slots
|
||||
assert l1_sf_out.shape[0] == num_slots
|
||||
l2_alpha = l2_igs if l2_igs is not None else (float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale)
|
||||
l2_alpha = float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
_l1sf_f32 = l1_sf_out.to(torch.float32)
|
||||
|
||||
@@ -552,12 +552,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# Quantize activation using the kernel's PyTorch stage_activation
|
||||
# Use the checkpoint's input_scale for L1 (w13) activation quantization.
|
||||
# The checkpoint's input_scale was used during weight calibration — using
|
||||
# the same scale at runtime ensures the quantized weights are rescaled correctly.
|
||||
# Dynamic stage_activation computes amax/(6*448) which can be 10x+ off.
|
||||
w13_input_scale = float(self._w13_input_scale[0]) # same for all experts
|
||||
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale)
|
||||
# Dynamic quantization: input_global_scale = amax / (6 * 448)
|
||||
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
|
||||
symm_buffer.x[:num_tokens].copy_(x_fp4)
|
||||
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
|
||||
symm_buffer.input_global_scale = input_global_scale
|
||||
|
||||
Reference in New Issue
Block a user