fix: detect zero blocks in NVFP4 quantization, force FP4+FP8 to exact zero

Bug #3 fix: The clamp(min=1e-8) on block_amax prevented NaN from 0/0
but allowed truly-zero blocks to get a nonzero FP8 scale (5e-12 from
underflow). While the kernel produces 0 * 0 = 0 (no NaN), the nonzero
scale is semantically wrong and could interact badly with future kernels.

Fix: detect zero blocks explicitly (block_amax == 0), clamp only for
safe division, then force FP8 scale to exact zero for zero blocks via
torch.where. The FP4 nibbles are already zero (0 / anything = 0).

Verified: checkpoint byte match remains 100%, zero blocks produce
exact-zero dequantization, no NaN propagation.

Applies to all three quantization functions:
- quantize_to_nvfp4 (activation with computed gs)
- quantize_activation_nvfp4 (activation with pre-computed gs)
- quantize_weight_to_nvfp4 (weight quantization)
This commit is contained in:
2026-05-20 02:14:50 +00:00
parent 3c6b5a0522
commit c8fa87fac7

View File

@@ -98,8 +98,15 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
block_amax = x_reshaped.abs().amax(dim=-1)
zero_block = block_amax == 0
# Clamp for safe division; zero blocks get scale=0 after the where.
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
# Without this, the FP8 scale underflows to 0 anyway for small blocks
# but the div-by-tiny-number can produce nonzero FP4 from noise.
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
# Nearest E2M1
block_sf_expanded = block_scale.float().unsqueeze(-1)
@@ -154,8 +161,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
block_amax = x_reshaped.abs().amax(dim=-1)
zero_block = block_amax == 0
block_amax = block_amax.clamp(min=1e-8)
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
@@ -206,8 +217,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K))
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8)
w_block_amax = w_reshaped.abs().amax(dim=1)
zero_block = w_block_amax == 0
w_block_amax = w_block_amax.clamp(min=1e-8)
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
w_block_sf = w_sf.float().unsqueeze(1)
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)