From c8fa87fac73d94e31d4fcb92525e696ea5234390 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 20 May 2026 02:14:50 +0000 Subject: [PATCH] fix: detect zero blocks in NVFP4 quantization, force FP4+FP8 to exact zero Bug #3 fix: The clamp(min=1e-8) on block_amax prevented NaN from 0/0 but allowed truly-zero blocks to get a nonzero FP8 scale (5e-12 from underflow). While the kernel produces 0 * 0 = 0 (no NaN), the nonzero scale is semantically wrong and could interact badly with future kernels. Fix: detect zero blocks explicitly (block_amax == 0), clamp only for safe division, then force FP8 scale to exact zero for zero blocks via torch.where. The FP4 nibbles are already zero (0 / anything = 0). Verified: checkpoint byte match remains 100%, zero blocks produce exact-zero dequantization, no NaN propagation. Applies to all three quantization functions: - quantize_to_nvfp4 (activation with computed gs) - quantize_activation_nvfp4 (activation with pre-computed gs) - quantize_weight_to_nvfp4 (weight quantization) --- cutedsl/bridge.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 07e3048b..b20a2deb 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -98,8 +98,15 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): x_norm = torch.nn.functional.pad(x_norm, (0, pad_size)) x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) - block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) + block_amax = x_reshaped.abs().amax(dim=-1) + zero_block = block_amax == 0 + # Clamp for safe division; zero blocks get scale=0 after the where. + block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) + # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. + # Without this, the FP8 scale underflows to 0 anyway for small blocks + # but the div-by-tiny-number can produce nonzero FP4 from noise. + block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) # Nearest E2M1 block_sf_expanded = block_scale.float().unsqueeze(-1) @@ -154,8 +161,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): x_norm = torch.nn.functional.pad(x_norm, (0, pad_size)) x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) - block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) + block_amax = x_reshaped.abs().amax(dim=-1) + zero_block = block_amax == 0 + block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) + # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. + block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8) @@ -206,8 +217,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K)) w_reshaped = w_norm.reshape(k_blocks, block_size, N) - w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8) + w_block_amax = w_reshaped.abs().amax(dim=1) + zero_block = w_block_amax == 0 + w_block_amax = w_block_amax.clamp(min=1e-8) w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn) + # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. + w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf) w_block_sf = w_sf.float().unsqueeze(1) w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)