diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 07e3048b..b20a2deb 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -98,8 +98,15 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): x_norm = torch.nn.functional.pad(x_norm, (0, pad_size)) x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) - block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) + block_amax = x_reshaped.abs().amax(dim=-1) + zero_block = block_amax == 0 + # Clamp for safe division; zero blocks get scale=0 after the where. + block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) + # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. + # Without this, the FP8 scale underflows to 0 anyway for small blocks + # but the div-by-tiny-number can produce nonzero FP4 from noise. + block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) # Nearest E2M1 block_sf_expanded = block_scale.float().unsqueeze(-1) @@ -154,8 +161,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): x_norm = torch.nn.functional.pad(x_norm, (0, pad_size)) x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) - block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) + block_amax = x_reshaped.abs().amax(dim=-1) + zero_block = block_amax == 0 + block_amax = block_amax.clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) + # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. + block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8) @@ -206,8 +217,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K)) w_reshaped = w_norm.reshape(k_blocks, block_size, N) - w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8) + w_block_amax = w_reshaped.abs().amax(dim=1) + zero_block = w_block_amax == 0 + w_block_amax = w_block_amax.clamp(min=1e-8) w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn) + # Zero blocks: force both FP4 nibbles and FP8 scale to exact zero. + w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf) w_block_sf = w_sf.float().unsqueeze(1) w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)