From 4882d8553cc028a8d4b1a90877ec388ed58beae0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 20 May 2026 02:16:49 +0000 Subject: [PATCH] fix: zero out x_norm for underflow blocks before division in NVFP4 quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug #4 fix: When a block has amax > 0 but amax/6 underflows to 0 in FP8 (amax < 6*2^-9 ≈ 0.0117), the block scale is 0, but the division x / clamp(0, 1e-8) inflates x into nonzero FP4 buckets (up to ±6.0). This produces semantically wrong FP4 even though dequant gives 0 (6*0=0). Root cause: we only detected truly-zero blocks (amax == 0) but not underflow blocks (0 < amax < FP8_threshold). The fix: 1. Detect both zero and underflow blocks: block_amax < 6 * 2^-9 2. Zero out x_reshaped for these blocks BEFORE division 3. Force FP8 scale to 0 for these blocks This ensures x_scaled = 0 → FP4 nibbles = 0 → dequant = 0. Verified: bug scenario now produces nibble=0, scale=0. Checkpoint byte match remains 100%. --- cutedsl/bridge.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index b20a2deb..2849ceb3 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -99,13 +99,18 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) block_amax = x_reshaped.abs().amax(dim=-1) - zero_block = block_amax == 0 - # Clamp for safe division; zero blocks get scale=0 after the where. + # Detect zero blocks and underflow blocks (amax > 0 but too small for FP8). + # Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this, + # the block scale underflows to 0, and dividing x by the clamped 1e-8 + # inflates values into nonzero FP4 buckets — producing wrong results. + zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117 + # Zero out x for zero/underflow blocks before division. + # This ensures x_scaled = 0 → FP4 nibbles = 0. + x_reshaped = torch.where(zero_block.unsqueeze(-1), + torch.zeros_like(x_reshaped), x_reshaped) block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. - # Without this, the FP8 scale underflows to 0 anyway for small blocks - # but the div-by-tiny-number can produce nonzero FP4 from noise. + # Force zero/underflow blocks: FP8 scale = 0 (exact zero). block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) # Nearest E2M1 @@ -162,10 +167,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) block_amax = x_reshaped.abs().amax(dim=-1) - zero_block = block_amax == 0 + # Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4). + zero_block = block_amax < (6.0 * 2.0 ** -9) + x_reshaped = torch.where(zero_block.unsqueeze(-1), + torch.zeros_like(x_reshaped), x_reshaped) block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) block_sf_expanded = block_scale.float().unsqueeze(-1) @@ -218,10 +225,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): w_reshaped = w_norm.reshape(k_blocks, block_size, N) w_block_amax = w_reshaped.abs().amax(dim=1) - zero_block = w_block_amax == 0 + # Detect zero blocks and underflow blocks (same threshold). + zero_block = w_block_amax < (6.0 * 2.0 ** -9) + w_reshaped = torch.where(zero_block.unsqueeze(1), + torch.zeros_like(w_reshaped), w_reshaped) w_block_amax = w_block_amax.clamp(min=1e-8) w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn) - # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf) w_block_sf = w_sf.float().unsqueeze(1)