fix: zero out x_norm for underflow blocks before division in NVFP4 quantization
Bug #4 fix: When a block has amax > 0 but amax/6 underflows to 0 in FP8 (amax < 6*2^-9 ≈ 0.0117), the block scale is 0, but the division x / clamp(0, 1e-8) inflates x into nonzero FP4 buckets (up to ±6.0). This produces semantically wrong FP4 even though dequant gives 0 (6*0=0). Root cause: we only detected truly-zero blocks (amax == 0) but not underflow blocks (0 < amax < FP8_threshold). The fix: 1. Detect both zero and underflow blocks: block_amax < 6 * 2^-9 2. Zero out x_reshaped for these blocks BEFORE division 3. Force FP8 scale to 0 for these blocks This ensures x_scaled = 0 → FP4 nibbles = 0 → dequant = 0. Verified: bug scenario now produces nibble=0, scale=0. Checkpoint byte match remains 100%.
This commit is contained in:
@@ -99,13 +99,18 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
|
||||
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
zero_block = block_amax == 0
|
||||
# Clamp for safe division; zero blocks get scale=0 after the where.
|
||||
# Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
|
||||
# Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
|
||||
# the block scale underflows to 0, and dividing x by the clamped 1e-8
|
||||
# inflates values into nonzero FP4 buckets — producing wrong results.
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9) # < ~0.0117
|
||||
# Zero out x for zero/underflow blocks before division.
|
||||
# This ensures x_scaled = 0 → FP4 nibbles = 0.
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
|
||||
# Without this, the FP8 scale underflows to 0 anyway for small blocks
|
||||
# but the div-by-tiny-number can produce nonzero FP4 from noise.
|
||||
# Force zero/underflow blocks: FP8 scale = 0 (exact zero).
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
|
||||
# Nearest E2M1
|
||||
@@ -162,10 +167,12 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
|
||||
|
||||
x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
zero_block = block_amax == 0
|
||||
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9)
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
@@ -218,10 +225,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
|
||||
|
||||
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
|
||||
w_block_amax = w_reshaped.abs().amax(dim=1)
|
||||
zero_block = w_block_amax == 0
|
||||
# Detect zero blocks and underflow blocks (same threshold).
|
||||
zero_block = w_block_amax < (6.0 * 2.0 ** -9)
|
||||
w_reshaped = torch.where(zero_block.unsqueeze(1),
|
||||
torch.zeros_like(w_reshaped), w_reshaped)
|
||||
w_block_amax = w_block_amax.clamp(min=1e-8)
|
||||
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
# Zero blocks: force both FP4 nibbles and FP8 scale to exact zero.
|
||||
w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)
|
||||
|
||||
w_block_sf = w_sf.float().unsqueeze(1)
|
||||
|
||||
Reference in New Issue
Block a user