diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index b20a2deb..2849ceb3 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -99,13 +99,18 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) block_amax = x_reshaped.abs().amax(dim=-1) - zero_block = block_amax == 0 - # Clamp for safe division; zero blocks get scale=0 after the where. + # Detect zero blocks and underflow blocks (amax > 0 but too small for FP8). + # Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this, + # the block scale underflows to 0, and dividing x by the clamped 1e-8 + # inflates values into nonzero FP4 buckets — producing wrong results. + zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117 + # Zero out x for zero/underflow blocks before division. + # This ensures x_scaled = 0 → FP4 nibbles = 0. + x_reshaped = torch.where(zero_block.unsqueeze(-1), + torch.zeros_like(x_reshaped), x_reshaped) block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. - # Without this, the FP8 scale underflows to 0 anyway for small blocks - # but the div-by-tiny-number can produce nonzero FP4 from noise. + # Force zero/underflow blocks: FP8 scale = 0 (exact zero). block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) # Nearest E2M1 @@ -162,10 +167,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) block_amax = x_reshaped.abs().amax(dim=-1) - zero_block = block_amax == 0 + # Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4). + zero_block = block_amax < (6.0 * 2.0 ** -9) + x_reshaped = torch.where(zero_block.unsqueeze(-1), + torch.zeros_like(x_reshaped), x_reshaped) block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) block_sf_expanded = block_scale.float().unsqueeze(-1) @@ -218,10 +225,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): w_reshaped = w_norm.reshape(k_blocks, block_size, N) w_block_amax = w_reshaped.abs().amax(dim=1) - zero_block = w_block_amax == 0 + # Detect zero blocks and underflow blocks (same threshold). + zero_block = w_block_amax < (6.0 * 2.0 ** -9) + w_reshaped = torch.where(zero_block.unsqueeze(1), + torch.zeros_like(w_reshaped), w_reshaped) w_block_amax = w_block_amax.clamp(min=1e-8) w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn) - # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf) w_block_sf = w_sf.float().unsqueeze(1)