From d547da294864e2e896233dcec02bfca97f8cb3a8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 03:27:47 +0000 Subject: [PATCH] stage_activation: add per-tensor global scale matching NVFP4 spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without a global scale, block scales (block_max / 6.0) could exceed UE4M3 max (448.0) for large activations, causing saturation and garbage MoE outputs. The degeneration pattern (positions 1-5 OK, then constant spaces) is consistent with UE4M3 overflow: first few tokens have small activations that fit, but once SiLU(mul(gate, up)) produces larger values, block scales overflow and the GEMM produces zeros/garbage. Fix: compute input_global_scale = amax / (6.0 * 448.0), normalize before block quantization, then fold global scale back into block scales (same as weight_transform.py folds weight_scale_2). This ensures block scales are always ≤ 448.0 in UE4M3 range. --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 40 ++++++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 148db3ed..9eaa5e9b 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -220,15 +220,41 @@ def _quantize_to_e2m1(x_f32): def stage_activation(x_bf16): """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. - Proper E2M1 quantization: - - Per-block (16 values) absmax scaling - - Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} - - Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index) - - Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn) + Two-level quantization matching the NVFP4 weight format: + 1. Per-tensor global scale: amax / (6.0 * 448.0) + Normalizes the activation so that block scales fit in UE4M3 range. + 2. Per-block (16 values) absmax scaling on the normalized values + Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} + Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index) + Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn) + + The global scale is folded into the block scales (same as weight_transform + folds weight_scale_2 into weight_scale), so the GEMM API is unchanged: + dequant = e2m1_magnitude * block_scale + where block_scale = (block_max / 6.0) * global_scale """ x_f32 = x_bf16.float() - x_fp4, x_sf = _quantize_to_e2m1(x_f32) - return x_fp4, x_sf + + # Per-tensor global scale (same role as weight_scale_2) + # NVFP4 spec: global_scale = amax / (6.0 * 448.0) + # This ensures the largest block scale after normalization is ~448.0, + # which fits exactly in UE4M3 max (448.0 for E4M3). + x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8) + input_global_scale = x_amax / (6.0 * 448.0) + + # Normalize by global scale before block quantization. + # After this, values are in a range where block_max / 6.0 ≤ 448.0, + # so block scales fit in UE4M3 without saturation. + x_normalized = x_f32 / input_global_scale + + x_fp4, x_sf = _quantize_to_e2m1(x_normalized) + + # Fold global scale into block scales: sf_folded = sf * global_scale + # (same as _fold_global_scale in weight_transform.py) + sf_folded = x_sf.to(torch.float32) * input_global_scale + x_sf_folded = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + + return x_fp4, x_sf_folded def nvfp4_mega_moe_full(