debug: add FP4 quantization round-trip diagnostic
This commit is contained in:
@@ -322,6 +322,28 @@ def nvfp4_mega_moe_full(
|
||||
x_fp4 = symm_buffer.x[:num_tokens]
|
||||
x_sf = symm_buffer.x_sf[:num_tokens]
|
||||
l1_global_scale = symm_buffer.input_global_scale
|
||||
|
||||
# Diagnostic: check FP4 quantization quality by dequantizing and comparing
|
||||
if not getattr(nvfp4_mega_moe_full, '_quant_diag', False):
|
||||
nvfp4_mega_moe_full._quant_diag = True
|
||||
# Dequantize: FP4 → BF16 round-trip check
|
||||
x_u8 = x_fp4.view(torch.uint8)
|
||||
lo = (x_u8 & 0x0F).to(torch.int8) # low nibble
|
||||
hi = ((x_u8 >> 4) & 0x0F).to(torch.int8) # high nibble
|
||||
# Interleave back to (num_tokens, K)
|
||||
x_nibbles = torch.stack([lo, hi], dim=-1).reshape(num_tokens, -1) # (T, K)
|
||||
signs = (x_nibbles >> 3).float() * -2 + 1 # +1 or -1
|
||||
mags = _E2M1_MAGNITUDES.to(device=x_nibbles.device)[(x_nibbles & 0x07).long()]
|
||||
x_deq = signs * mags # (T, K) in E2M1 magnitudes
|
||||
# Apply block scales and global scale
|
||||
sf_expanded = x_sf.to(torch.float32).repeat_interleave(16, dim=-1) # (T, K)
|
||||
igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
|
||||
x_reconstructed = x_deq * sf_expanded * igs
|
||||
print(f"[QUANT-DIAG] x_fp4 amax={x_fp4.view(torch.uint8).float().amax():.0f} "
|
||||
f"x_sf range=[{x_sf.to(torch.float32).min():.2f}, {x_sf.to(torch.float32).max():.2f}] "
|
||||
f"igs={igs:.4e}")
|
||||
print(f"[QUANT-DIAG] reconstructed amax={x_reconstructed.abs().max():.4e} "
|
||||
f"mean={x_reconstructed.mean():.4e}")
|
||||
topk_ids = symm_buffer.topk_idx[:num_tokens]
|
||||
topk_weights = symm_buffer.topk_weights[:num_tokens]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user