"""Test: Check if dequantize→requantize preserves checkpoint FP4 weights. If the rounding convention differs (round-half-up vs round-half-to-even), the FP4 buckets shift and accumulated error across 1.6T weights matters. This test: 1. Loads a single expert's checkpoint FP4 weights 2. Dequantizes them to BF16 (using the checkpoint's scales/gs) 3. Re-quantizes BF16 → FP4 (using our quantize_weight_to_nvfp4) 4. Dequantizes the re-quantized weights back to BF16 5. Compares: ||W_ours - W_checkpoint||_F / ||W_checkpoint||_F If > 1e-3, there's a rounding mismatch that matters. """ import sys import torch def dequantize_nvfp4_weight(packed_uint8, scale_e4m3, global_scale): """Dequantize NVFP4 weight tensor to BF16. Args: packed_uint8: (N, K_packed) uint8 — packed FP4 bytes scale_e4m3: (N, K_sf) float8_e4m3fn — block scales global_scale: float32 scalar Returns: (N, K_original) BF16 weight matrix """ # Unpack FP4 → BF16 raw = packed_uint8.view(torch.uint8) low = (raw & 0x0F).to(torch.int8) # even elements high = ((raw >> 4) & 0x0F).to(torch.int8) # odd elements # E2M1 magnitudes: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] e2m1_values = torch.tensor([0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32, device=raw.device) # Sign: bit 3 = 1 means negative low_sign = (low >> 3).bool() low_idx = (low & 0x07) high_sign = (high >> 3).bool() high_idx = (high & 0x07) low_mag = e2m1_values[low_idx.long()] high_mag = e2m1_values[high_idx.long()] low_val = torch.where(low_sign, -low_mag, low_mag) high_val = torch.where(high_sign, -high_mag, high_mag) # Interleave low and high N, K_packed = packed_uint8.shape K = K_packed * 2 values = torch.stack([low_val, high_val], dim=-1).reshape(N, K) # Dequantize: value * block_scale * global_scale K_sf = scale_e4m3.shape[1] block_size = K // K_sf # Should be 16 scale_expanded = scale_e4m3.float().unsqueeze(2).expand(-1, -1, block_size).reshape(N, K) dequant = values * scale_expanded * global_scale return dequant.to(torch.bfloat16) def test_roundtrip(): from safetensors.torch import load_file import glob # Load checkpoint files = sorted(glob.glob('/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4/*.safetensors')) if not files: print("No safetensor files found") return # Find a file with expert 0 weights d = load_file(files[0]) # Test gate_proj of expert 0 gate_w = d['model.layers.0.mlp.experts.0.gate_proj.weight'].cuda() # (3072, 3584) uint8 gate_sf = d['model.layers.0.mlp.experts.0.gate_proj.weight_scale'].cuda() # (3072, 448) fp8 gate_gs = d['model.layers.0.mlp.experts.0.gate_proj.weight_scale_2'].item() # float32 print(f"Checkpoint: gate_proj shape={gate_w.shape}, sf shape={gate_sf.shape}, gs={gate_gs}") # Step 1: Dequantize checkpoint → BF16 gate_bf16 = dequantize_nvfp4_weight(gate_w, gate_sf, gate_gs) print(f"Dequantized: shape={gate_bf16.shape}, amax={gate_bf16.abs().amax().item():.6f}") # Step 2: Re-quantize BF16 → FP4 using our convention sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel') from dsv4.ops.quantize import ( quantize_weight_to_nvfp4, ) # quantize_weight_to_nvfp4 expects (K, N) where K is the packed dim # Our gate is (3072, 7168) in BF16, so K=3072, N=7168 # But the checkpoint stores it as (3072, 3584) uint8 = (3072, 7168//2) packed # The dequantized shape is (3072, 7168) BF16 # quantize_weight_to_nvfp4 expects (K, N) = (3072, 7168) w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(gate_bf16) print(f"Re-quantized: fp4 shape={w_fp4.shape}, sf shape={w_sf.shape}, gs={w_gs}") # Step 3: Dequantize the re-quantized weights gate_bf16_ours = dequantize_nvfp4_weight( w_fp4.view(torch.uint8) if w_fp4.dtype != torch.uint8 else w_fp4, w_sf, w_gs, ) # Step 4: Compare diff = (gate_bf16_ours - gate_bf16).float() rel_err = diff.norm() / gate_bf16.float().norm() max_err = diff.abs().max() print(f"\n=== Results ===") print(f"Relative error (Frobenius): {rel_err.item():.6f}") print(f"Max absolute error: {max_err.item():.6f}") print(f"Threshold: 1e-3 = {rel_err.item() > 1e-3}") # Step 5: Compare raw FP4 bytes our_uint8 = w_fp4.view(torch.uint8) if w_fp4.dtype != torch.uint8 else w_fp4 byte_match = (our_uint8 == gate_w).float().mean() print(f"FP4 byte match rate: {byte_match.item():.4f}") # Step 6: Check where they differ if byte_match.item() < 1.0: mismatch = (our_uint8 != gate_w) mismatch_count = mismatch.sum().item() total = gate_w.numel() print(f"Byte mismatches: {mismatch_count}/{total} ({mismatch_count/total*100:.2f}%)") # Sample some mismatches idx = mismatch.nonzero()[:5] for i in range(min(5, len(idx))): r, c = idx[i].tolist() ours = our_uint8[r, c].item() theirs = gate_w[r, c].item() print(f" [{r},{c}] ours=0x{ours:02x} checkpoint=0x{theirs:02x}") # Step 7: Also compare scales if w_sf.shape == gate_sf.shape and w_sf.dtype == gate_sf.dtype: sf_match = (w_sf == gate_sf).float().mean() print(f"Block scale match rate: {sf_match.item():.4f}") print(f"\nGlobal scale: ours={w_gs:.8f}, checkpoint={gate_gs:.8f}, diff={abs(w_gs - gate_gs):.8f}") if __name__ == "__main__": test_roundtrip()