#!/usr/bin/env python3
"""Sanity check an NVFP4 DeepSeek V4 Pro checkpoint.

The tensor-level check always runs; the vLLM check is opt-in:

1) Tensor check (always): no model loading. Just inspects the safetensors
   shards: confirms NVFP4 packing structure (uint8 weight + FP8 weight_scale
   + FP32 weight_scale_2), checks for NaN/Inf in scales, and dequantizes a
   small sample to confirm the values look plausible.

2) --vllm: additionally tries to load the model with vLLM and generate a few
   tokens. Requires vLLM with NVFP4 support (SM100+ Blackwell GPU).

Usage:
    python verify_nvfp4.py DeepSeek-V4-Pro-NVFP4-streaming
    python verify_nvfp4.py DeepSeek-V4-Pro-NVFP4-streaming --vllm
"""

import argparse
import json
import sys
from pathlib import Path

import torch

try:
    from safetensors import safe_open
except ImportError:
    # Keep the pure-tensor helpers importable without safetensors; the
    # missing dependency is reported when a shard actually has to be read.
    safe_open = None

# Number of FP4 elements that share one entry of the block-scale tensor.
NVFP4_BLOCK_SIZE = 16

# Lookup table from a 4-bit code (0..15) to the value it is decoded as.
# Codes 8..15 are the negated counterparts of codes 0..7.
FP4_E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

_WEIGHT_SUFFIX = ".weight"


class VerificationError(AssertionError):
    """Raised when a checkpoint tensor fails a structural or numeric check.

    Subclasses AssertionError so code that caught the failures of the former
    ``assert`` statements keeps working, while the checks themselves are no
    longer stripped when Python runs with ``-O``.
    """


def _require(condition, message):
    """Raise VerificationError(message) unless *condition* is truthy."""
    if not condition:
        raise VerificationError(message)


def _companion_name(weight_name: str, suffix: str) -> str:
    """Swap the trailing ".weight" of *weight_name* for *suffix*.

    Only the final suffix is touched, so names that contain ".weight"
    elsewhere (e.g. "a.weight_proj.weight") are handled correctly.
    """
    return weight_name[: -len(_WEIGHT_SUFFIX)] + suffix


def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    """Reverse the (low | high<<4) byte pack into a tensor of FP4 indices.

    Args:
        packed: uint8 tensor of shape [..., N/2]; each byte holds two 4-bit
            codes, the first in the low nibble and the second in the high
            nibble.

    Returns:
        uint8 tensor of shape [..., N] on the same device as ``packed``, with
        the low-nibble codes at even positions and the high-nibble codes at
        odd positions of the last dimension.
    """
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Stacking on a new trailing axis and flattening it interleaves
    # low/high exactly like out[..., ::2] = low; out[..., 1::2] = high.
    return torch.stack((low, high), dim=-1).flatten(start_dim=-2)


def dequant_nvfp4(packed_uint8, weight_scale_fp8, weight_scale_2_fp32):
    """Reconstruct FP32 values from NVFP4 storage.

    Args:
        packed_uint8: uint8 tensor [..., N/2] of packed 4-bit codes.
        weight_scale_fp8: per-block scales of shape [..., N/16]; converted to
            float32 before use.
        weight_scale_2_fp32: second-level scale multiplied into every element.

    Returns:
        float32 tensor of shape [..., N]:
        ``FP4 value * block scale * weight_scale_2``.

    Raises:
        ValueError: if the block-scale shape does not match the unpacked
            weight shape.
    """
    fp4_idx = unpack_fp4(packed_uint8)
    lookup = FP4_E2M1_VALUES.to(fp4_idx.device)
    values = lookup[fp4_idx.long()]  # [..., N]

    block_scales = weight_scale_fp8.float()  # [..., N // 16]
    n_elems = values.shape[-1]
    expected = tuple(values.shape[:-1]) + (n_elems // NVFP4_BLOCK_SIZE,)
    if n_elems % NVFP4_BLOCK_SIZE or tuple(block_scales.shape) != expected:
        raise ValueError(
            f"weight_scale shape {tuple(block_scales.shape)} does not match "
            f"unpacked weight shape {tuple(values.shape)} "
            f"(block size {NVFP4_BLOCK_SIZE})"
        )

    # Per-block scale broadcast back over the 16 elements of each block.
    scale_per_elem = block_scales.repeat_interleave(NVFP4_BLOCK_SIZE, dim=-1)
    return values * scale_per_elem * weight_scale_2_fp32.float()


def _load_tensor(model_dir: Path, shard_fn: str, name: str) -> torch.Tensor:
    """Read the tensor *name* from the shard file *shard_fn* in *model_dir*."""
    with safe_open(model_dir / shard_fn, framework="pt") as f:
        return f.get_tensor(name)


def tensor_only_check(model_dir: Path):
    """Inspect the safetensors shards of *model_dir* without loading a model.

    Picks the first ``*.weight`` entry of the index that has a
    ``*.weight_scale`` companion, checks dtypes/shapes/finiteness of the
    packed weight and its scales, dequantizes a small corner as a spot check,
    and prints a quantized-vs-preserved summary over the whole index.

    Raises:
        SystemExit: if the index is missing or malformed, no quantized weight
            is found, the sampled weight has no ``weight_scale_2`` entry, or
            safetensors is not installed.
        VerificationError: if a structural or numeric check fails.
    """
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        sys.exit(f"No index.json at {model_dir}")
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index.get("weight_map") if isinstance(index, dict) else None
    if not isinstance(weight_map, dict):
        sys.exit(f"{index_path} has no 'weight_map' object")

    # Find one quantized weight to sample
    sample = next(
        (
            name
            for name in weight_map
            if name.endswith(_WEIGHT_SUFFIX)
            and _companion_name(name, ".weight_scale") in weight_map
        ),
        None,
    )
    if sample is None:
        sys.exit("Couldn't find an NVFP4-quantized weight (expected *.weight_scale companion).")

    print(f"Sampling: {sample}")
    scale_name = _companion_name(sample, ".weight_scale")
    scale_2_name = _companion_name(sample, ".weight_scale_2")
    if scale_2_name not in weight_map:
        sys.exit(f"{sample} has a weight_scale but no {scale_2_name}; not an NVFP4 tensor.")
    if safe_open is None:
        sys.exit("The 'safetensors' package is required to read the checkpoint shards.")

    packed = _load_tensor(model_dir, weight_map[sample], sample)
    weight_scale = _load_tensor(model_dir, weight_map[scale_name], scale_name)
    weight_scale_2 = _load_tensor(model_dir, weight_map[scale_2_name], scale_2_name)

    print(f" packed: shape={tuple(packed.shape)} dtype={packed.dtype}")
    print(f" weight_scale: shape={tuple(weight_scale.shape)} dtype={weight_scale.dtype}")
    print(f" weight_scale_2: shape={tuple(weight_scale_2.shape)} dtype={weight_scale_2.dtype} "
          f"value={weight_scale_2.float().item():.6e}")

    # Structural checks (explicit raises: these must survive `python -O`).
    _require(packed.dtype == torch.uint8, f"packed should be uint8, got {packed.dtype}")
    _require(packed.dim() == 2, f"packed should be 2-D, got shape {tuple(packed.shape)}")
    M = packed.shape[0]
    _require(weight_scale.dtype == torch.float8_e4m3fn,
             f"weight_scale should be FP8 E4M3, got {weight_scale.dtype}")
    _require(weight_scale.shape == (M, packed.shape[1] * 2 // NVFP4_BLOCK_SIZE),
             f"weight_scale shape {weight_scale.shape} doesn't match expected (M, N/16)")

    # Check for NaN/Inf in scales
    s_fp32 = weight_scale.float()
    _require(bool(torch.isfinite(s_fp32).all()), "weight_scale contains NaN/Inf")
    _require(bool(torch.isfinite(weight_scale_2.float()).all()), "weight_scale_2 is NaN/Inf")
    print(" scales: all finite ✓")
    print(f" weight_scale stats: min={s_fp32.min().item():.3e} max={s_fp32.max().item():.3e} "
          f"mean={s_fp32.mean().item():.3e}")

    # Spot-check dequantization: 16 packed bytes = 32 values = 2 scale blocks.
    print("\nDequantizing first 4x32 block for visual check:")
    rec = dequant_nvfp4(packed[:4, :16], weight_scale[:4, :2], weight_scale_2)
    print(rec)
    _require(bool(torch.isfinite(rec).all()), "Dequantized values contain NaN/Inf")
    print(" dequant: all finite ✓")
    print(f" dequant range: [{rec.min().item():.4f}, {rec.max().item():.4f}]")

    # Count what's quantized vs preserved across the whole model
    quantized_weights = []
    preserved = []
    for name in weight_map:
        if not name.endswith(_WEIGHT_SUFFIX):
            continue
        if _companion_name(name, ".weight_scale") in weight_map:
            quantized_weights.append(name)
        else:
            preserved.append(name)

    print("\nWhole-model summary:")
    print(f" Quantized .weight tensors: {len(quantized_weights):,}")
    print(f" Preserved .weight tensors: {len(preserved):,}")
    print(f" Total tensors in index: {len(weight_map):,}")

    # Show a few preserved names to confirm the right things stayed in higher precision
    print("\n Sample preserved tensors (should be lm_head, embed, gates, norms, etc.):")
    for n in preserved[:10]:
        print(f" {n}")


def vllm_check(model_dir: Path, tensor_parallel_size: int = 8, max_model_len: int = 8192):
    """Load the checkpoint with vLLM and print completions for a few prompts.

    Args:
        model_dir: checkpoint directory handed to ``vllm.LLM``.
        tensor_parallel_size: forwarded to ``vllm.LLM`` (default 8).
        max_model_len: forwarded to ``vllm.LLM`` (default 8192).
    """
    print("Loading model with vLLM... (requires Blackwell GPU + vLLM with NVFP4 support)")
    # Imported lazily so the tensor-only path works without vLLM installed.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=str(model_dir),
        trust_remote_code=True,
        quantization="compressed-tensors",
        dtype="auto",
        tensor_parallel_size=tensor_parallel_size,
        max_model_len=max_model_len,
    )
    sampling = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=64)
    prompts = [
        "Write a short poem about quantization:",
        "What is 17 * 23?",
        "Explain MoE routing in one sentence.",
    ]
    outputs = llm.generate(prompts, sampling)
    for o in outputs:
        print("=" * 60)
        print("PROMPT:", o.prompt)
        print("OUTPUT:", o.outputs[0].text)


def main():
    """Parse the command line, run the tensor check, then optionally vLLM."""
    ap = argparse.ArgumentParser(
        description="Sanity check an NVFP4 checkpoint (tensor inspection, optional vLLM load)."
    )
    ap.add_argument("model_dir")
    ap.add_argument("--vllm", action="store_true",
                    help="also load the model with vLLM and generate a few tokens")
    ap.add_argument("--tensor-parallel-size", type=int, default=8,
                    help="tensor parallel size for --vllm (default: 8)")
    ap.add_argument("--max-model-len", type=int, default=8192,
                    help="max model length for --vllm (default: 8192)")
    args = ap.parse_args()

    model_dir = Path(args.model_dir)
    tensor_only_check(model_dir)
    if args.vllm:
        print("\n" + "=" * 60)
        vllm_check(
            model_dir,
            tensor_parallel_size=args.tensor_parallel_size,
            max_model_len=args.max_model_len,
        )


if __name__ == "__main__":
    main()