fix: zero out x_norm for underflow blocks before division in NVFP4 quantization

Bug #4 fix: When a block has amax > 0 but amax/6 underflows to 0 in
FP8 (amax < 6*2^-9 ≈ 0.0117), the block scale is 0, but the division
x / clamp(0, 1e-8) inflates x into nonzero FP4 buckets (up to ±6.0).
This produces semantically wrong FP4 even though dequant gives 0 (6*0=0).

Root cause: we only detected truly-zero blocks (amax == 0) but not
underflow blocks (0 < amax < FP8_threshold). The fix:

1. Detect both zero and underflow blocks: block_amax < 6 * 2^-9
2. Zero out x_reshaped for these blocks BEFORE division
3. Force FP8 scale to 0 for these blocks

This ensures x_scaled = 0 → FP4 nibbles = 0 → dequant = 0.
Verified: bug scenario now produces nibble=0, scale=0.
Checkpoint byte match remains 100%.
This commit is contained in:
2026-05-20 02:16:49 +00:00
parent e653712598
commit 4882d8553c

View File

@@ -99,13 +99,18 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1)
zero_block = block_amax == 0
# Clamp for safe division; zero blocks get scale=0 after the where.
# Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
# Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
# the block scale underflows to 0, and dividing x by the clamped 1e-8
# inflates values into nonzero FP4 buckets — producing wrong results.
zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117
# Zero out x for zero/underflow blocks before division.
# This ensures x_scaled = 0 → FP4 nibbles = 0.
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
# Without this, the FP8 scale underflows to 0 anyway for small blocks
# but the div-by-tiny-number can produce nonzero FP4 from noise.
# Force zero/underflow blocks: FP8 scale = 0 (exact zero).
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
# Nearest E2M1
@@ -162,10 +167,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1)
zero_block = block_amax == 0
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
zero_block = block_amax < (6.0 * 2.0 ** -9)
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
block_sf_expanded = block_scale.float().unsqueeze(-1)
@@ -218,10 +225,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
w_block_amax = w_reshaped.abs().amax(dim=1)
zero_block = w_block_amax == 0
# Detect zero blocks and underflow blocks (same threshold).
zero_block = w_block_amax < (6.0 * 2.0 ** -9)
w_reshaped = torch.where(zero_block.unsqueeze(1),
torch.zeros_like(w_reshaped), w_reshaped)
w_block_amax = w_block_amax.clamp(min=1e-8)
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
w_block_sf = w_sf.float().unsqueeze(1)