stage_activation: add per-tensor global scale matching NVFP4 spec

Without a global scale, block scales (block_max / 6.0) could exceed
UE4M3 max (448.0) for large activations, causing saturation and garbage
MoE outputs. The degeneration pattern (positions 1-5 OK, then constant
spaces) is consistent with UE4M3 overflow: first few tokens have small
activations that fit, but once SiLU(mul(gate, up)) produces larger
values, block scales overflow and the GEMM produces zeros/garbage.

Fix: compute input_global_scale = amax / (6.0 * 448.0), normalize
before block quantization, then fold global scale back into block
scales (same as weight_transform.py folds weight_scale_2). This
ensures block scales are always ≤ 448.0 in UE4M3 range.
This commit is contained in:
2026-05-15 03:27:47 +00:00
parent 108ff07569
commit d547da2948

View File

@@ -220,15 +220,41 @@ def _quantize_to_e2m1(x_f32):
def stage_activation(x_bf16):
"""Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales.
Proper E2M1 quantization:
- Per-block (16 values) absmax scaling
- Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}
- Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index)
- Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn)
Two-level quantization matching the NVFP4 weight format:
1. Per-tensor global scale: amax / (6.0 * 448.0)
Normalizes the activation so that block scales fit in UE4M3 range.
2. Per-block (16 values) absmax scaling on the normalized values
Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}
Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index)
Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn)
The global scale is folded into the block scales (same as weight_transform
folds weight_scale_2 into weight_scale), so the GEMM API is unchanged:
dequant = e2m1_magnitude * block_scale
where block_scale = (block_max / 6.0) * global_scale
"""
x_f32 = x_bf16.float()
x_fp4, x_sf = _quantize_to_e2m1(x_f32)
return x_fp4, x_sf
# Per-tensor global scale (same role as weight_scale_2)
# NVFP4 spec: global_scale = amax / (6.0 * 448.0)
# This ensures the largest block scale after normalization is ~448.0,
# which fits exactly in UE4M3 max (448.0 for E4M3).
x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
input_global_scale = x_amax / (6.0 * 448.0)
# Normalize by global scale before block quantization.
# After this, values are in a range where block_max / 6.0 ≤ 448.0,
# so block scales fit in UE4M3 without saturation.
x_normalized = x_f32 / input_global_scale
x_fp4, x_sf = _quantize_to_e2m1(x_normalized)
# Fold global scale into block scales: sf_folded = sf * global_scale
# (same as _fold_global_scale in weight_transform.py)
sf_folded = x_sf.to(torch.float32) * input_global_scale
x_sf_folded = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
return x_fp4, x_sf_folded
def nvfp4_mega_moe_full(