From 995589ac8a93c83495a538dbaed9a38e60241f2f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 13:41:09 +0000 Subject: [PATCH] debug: add FP4 quantization round-trip diagnostic --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 206c777f..ae9eee05 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -322,6 +322,28 @@ def nvfp4_mega_moe_full( x_fp4 = symm_buffer.x[:num_tokens] x_sf = symm_buffer.x_sf[:num_tokens] l1_global_scale = symm_buffer.input_global_scale + + # Diagnostic: check FP4 quantization quality by dequantizing and comparing + if not getattr(nvfp4_mega_moe_full, '_quant_diag', False): + nvfp4_mega_moe_full._quant_diag = True + # Dequantize: FP4 → BF16 round-trip check + x_u8 = x_fp4.view(torch.uint8) + lo = (x_u8 & 0x0F).to(torch.int8) # low nibble + hi = ((x_u8 >> 4) & 0x0F).to(torch.int8) # high nibble + # Interleave back to (num_tokens, K) + x_nibbles = torch.stack([lo, hi], dim=-1).reshape(num_tokens, -1) # (T, K) + signs = (x_nibbles >> 3).float() * -2 + 1 # +1 or -1 + mags = _E2M1_MAGNITUDES.to(device=x_nibbles.device)[(x_nibbles & 0x07).long()] + x_deq = signs * mags # (T, K) in E2M1 magnitudes + # Apply block scales and global scale + sf_expanded = x_sf.to(torch.float32).repeat_interleave(16, dim=-1) # (T, K) + igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale + x_reconstructed = x_deq * sf_expanded * igs + print(f"[QUANT-DIAG] x_fp4 amax={x_fp4.view(torch.uint8).float().amax():.0f} " + f"x_sf range=[{x_sf.to(torch.float32).min():.2f}, {x_sf.to(torch.float32).max():.2f}] " + f"igs={igs:.4e}") + print(f"[QUANT-DIAG] reconstructed amax={x_reconstructed.abs().max():.4e} " + f"mean={x_reconstructed.mean():.4e}") topk_ids = symm_buffer.topk_idx[:num_tokens] topk_weights = symm_buffer.topk_weights[:num_tokens]