stage_activation: add per-tensor global scale matching NVFP4 spec
Without a global scale, block scales (block_max / 6.0) could exceed the UE4M3 max (448.0) for large activations, causing saturation and garbage MoE outputs.

The degeneration pattern (positions 1-5 OK, then constant spaces) is consistent with UE4M3 overflow: the first few tokens have small activations that fit, but once SiLU(mul(gate, up)) produces larger values, the block scales overflow and the GEMM produces zeros/garbage.

Fix: compute input_global_scale = amax / (6.0 * 448.0), normalize before block quantization, then fold the global scale back into the block scales (the same way weight_transform.py folds weight_scale_2). This keeps block scales ≤ 448.0, within UE4M3 range.
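The overflow threshold described above can be checked with plain arithmetic. The following is a minimal sketch, not code from this repository: the helper names and the two amax values are hypothetical and chosen only to land on either side of the 448.0 limit.

```python
E2M1_MAX = 6.0      # largest E2M1 magnitude
UE4M3_MAX = 448.0   # largest finite float8_e4m3fn value


def block_scale(block_max):
    """Single-level block scale: block_max / 6.0 (hypothetical helper)."""
    return block_max / E2M1_MAX


def global_scale(amax):
    """Per-tensor scale from the commit message: amax / (6.0 * 448.0)."""
    return amax / (E2M1_MAX * UE4M3_MAX)


# Hypothetical activation magnitudes, chosen only to illustrate the threshold.
small_amax = 12.0
large_amax = 5376.0  # 2 * 6.0 * 448.0

small_scale = block_scale(small_amax)   # fits in UE4M3
large_scale = block_scale(large_amax)   # exceeds UE4M3_MAX

# Normalizing by the global scale first brings the largest block scale
# down to exactly the UE4M3 maximum.
gs = global_scale(large_amax)
normalized_scale = block_scale(large_amax / gs)

print(small_scale, large_scale, gs, normalized_scale)
```

Any block whose absolute maximum exceeds 6.0 * 448.0 = 2688.0 produces an unnormalized block scale above 448.0, which is the condition the commit message attributes the degenerate outputs to.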
@@ -220,15 +220,41 @@ def _quantize_to_e2m1(x_f32):
 def stage_activation(x_bf16):
     """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales.

-    Proper E2M1 quantization:
-    - Per-block (16 values) absmax scaling
-    - Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}
-    - Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index)
-    - Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn)
+    Two-level quantization matching the NVFP4 weight format:
+    1. Per-tensor global scale: amax / (6.0 * 448.0)
+       Normalizes the activation so that block scales fit in UE4M3 range.
+    2. Per-block (16 values) absmax scaling on the normalized values
+       Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}
+       Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index)
+       Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn)
+
+    The global scale is folded into the block scales (same as weight_transform
+    folds weight_scale_2 into weight_scale), so the GEMM API is unchanged:
+        dequant = e2m1_magnitude * block_scale
+        where block_scale = (block_max / 6.0) * global_scale
     """
     x_f32 = x_bf16.float()
-    x_fp4, x_sf = _quantize_to_e2m1(x_f32)
-    return x_fp4, x_sf
+
+    # Per-tensor global scale (same role as weight_scale_2)
+    # NVFP4 spec: global_scale = amax / (6.0 * 448.0)
+    # This ensures the largest block scale after normalization is ~448.0,
+    # which fits exactly in UE4M3 max (448.0 for E4M3).
+    x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
+    input_global_scale = x_amax / (6.0 * 448.0)
+
+    # Normalize by global scale before block quantization.
+    # After this, values are in a range where block_max / 6.0 ≤ 448.0,
+    # so block scales fit in UE4M3 without saturation.
+    x_normalized = x_f32 / input_global_scale
+
+    x_fp4, x_sf = _quantize_to_e2m1(x_normalized)
+
+    # Fold global scale into block scales: sf_folded = sf * global_scale
+    # (same as _fold_global_scale in weight_transform.py)
+    sf_folded = x_sf.to(torch.float32) * input_global_scale
+    x_sf_folded = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
+
+    return x_fp4, x_sf_folded


 def nvfp4_mega_moe_full(
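The normalize, quantize, and fold steps in the hunk above can be traced on a single block in pure Python. This is a simplified sketch under stated assumptions: `quantize_block` is a hypothetical stand-in for `_quantize_to_e2m1` (whose body is not shown in this diff), scales are kept as Python floats with no FP8 rounding, no nibble packing is done, and the blocks are shorter than 16 values for readability.

```python
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0
UE4M3_MAX = 448.0


def snap_e2m1(v):
    """Round v to the nearest signed E2M1 value (ties go to the smaller magnitude)."""
    mag = min(E2M1_GRID, key=lambda g: abs(abs(v) - g))
    return -mag if v < 0 else mag


def quantize_block(block):
    """Hypothetical stand-in for _quantize_to_e2m1 on one block.

    Returns (codes, scale) with scale = block_max / 6.0, unrounded.
    """
    block_max = max(abs(v) for v in block)
    scale = max(block_max / E2M1_MAX, 1e-8)
    return [snap_e2m1(v / scale) for v in block], scale


def stage_block(block, tensor_amax):
    """Two-level path from the diff: normalize, block-quantize, fold, clamp."""
    gs = max(tensor_amax, 1e-8) / (E2M1_MAX * UE4M3_MAX)
    codes, sf = quantize_block([v / gs for v in block])
    sf_folded = min(max(sf * gs, 0.0), UE4M3_MAX)  # mirrors clamp(0.0, 448.0)
    return codes, sf, sf_folded


# Small activations: the normalized scale sits at about 448.0 and the
# folded scale returns to about block_max / 6.0.
codes, sf, sf_folded = stage_block([0.25, -1.0, 3.0, 0.0], tensor_amax=3.0)

# Large activations (amax = 2 * 6.0 * 448.0): folding multiplies the scale
# back up, so the final clamp is what bounds the stored value.
big_codes, big_sf, big_folded = stage_block([5376.0, 1344.0], tensor_amax=5376.0)

print(codes, sf, sf_folded)
print(big_codes, big_sf, big_folded)
```

In this simplified arithmetic, `sf * gs` works out to `block_max / 6.0` again, so for a tensor whose amax exceeds 6.0 * 448.0 the folded scale of the largest block is held in range by the final clamp rather than by the normalization itself. Whether that matters in practice depends on how `_quantize_to_e2m1` rounds its scales, which this sketch does not model.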