debug: add FP4 quantization round-trip diagnostic

This commit is contained in:
2026-05-15 13:41:09 +00:00
parent d0ed3d84a8
commit 995589ac8a

View File

@@ -322,6 +322,28 @@ def nvfp4_mega_moe_full(
x_fp4 = symm_buffer.x[:num_tokens]
x_sf = symm_buffer.x_sf[:num_tokens]
l1_global_scale = symm_buffer.input_global_scale
# Diagnostic (first call only): dequantize x_fp4 and print summary stats of the reconstruction (no reference tensor is compared here)
if not getattr(nvfp4_mega_moe_full, '_quant_diag', False):
nvfp4_mega_moe_full._quant_diag = True
# Dequantize: unpack two FP4 (E2M1) codes per byte, then decode sign and magnitude to floating point
x_u8 = x_fp4.view(torch.uint8)
lo = (x_u8 & 0x0F).to(torch.int8) # low nibble
hi = ((x_u8 >> 4) & 0x0F).to(torch.int8) # high nibble
# Interleave back to (num_tokens, K)
x_nibbles = torch.stack([lo, hi], dim=-1).reshape(num_tokens, -1) # (T, K)
signs = (x_nibbles >> 3).float() * -2 + 1 # +1 or -1
mags = _E2M1_MAGNITUDES.to(device=x_nibbles.device)[(x_nibbles & 0x07).long()]
x_deq = signs * mags # (T, K) in E2M1 magnitudes
# Apply block scales and global scale
sf_expanded = x_sf.to(torch.float32).repeat_interleave(16, dim=-1) # (T, K)
igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
x_reconstructed = x_deq * sf_expanded * igs
print(f"[QUANT-DIAG] x_fp4 amax={x_fp4.view(torch.uint8).float().amax():.0f} "
f"x_sf range=[{x_sf.to(torch.float32).min():.2f}, {x_sf.to(torch.float32).max():.2f}] "
f"igs={igs:.4e}")
print(f"[QUANT-DIAG] reconstructed amax={x_reconstructed.abs().max():.4e} "
f"mean={x_reconstructed.mean():.4e}")
topk_ids = symm_buffer.topk_idx[:num_tokens]
topk_weights = symm_buffer.topk_weights[:num_tokens]