#!/usr/bin/env python3
"""Compare the production NVFP4 GEMM against a PyTorch dequant reference.

For a handful of projections at selected layers this script loads only the
weights it needs and runs the production ``Nvfp4Linear`` on a random input. It
then compares that output with the reference ``F.linear(x, dequant_nvfp4(...))``.

It is a diagnostic for locating where the production kernel diverges from the
reference, which is suspected to cause the residual-growth issue.
"""
import os
import sys
import json
import math
import torch
import torch.nn.functional as F
from pathlib import Path

CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")

# Magnitudes of the eight non-negative FP4 codes 0..7; bit 3 of a code is its sign.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

# Tensor-name suffixes that together describe one quantized projection.
_NVFP4_SUFFIXES = ("weight", "weight_scale", "weight_scale_2", "input_scale")


def _as_packed_uint8(weight):
    """Return ``weight`` viewed as uint8 if it uses torch's packed FP4 dtype.

    Tensors of any other dtype are returned unchanged.
    """
    # float4_e2m1fn_x2 only exists in newer torch builds, hence getattr.
    fp4x2 = getattr(torch, "float4_e2m1fn_x2", None)
    if fp4x2 is not None and weight.dtype == fp4x2:
        return weight.view(torch.uint8)
    return weight


def _is_packed_fp4(weight):
    """Return True if ``weight`` is stored as uint8 or ``torch.float4_e2m1fn_x2``."""
    return _as_packed_uint8(weight).dtype == torch.uint8


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Dequantize a packed NVFP4 weight matrix to bfloat16.

    Args:
        weight: ``(O, I/2)`` tensor holding two 4-bit codes per byte, low
            nibble first (uint8, or ``torch.float4_e2m1fn_x2``).
        weight_scale: ``(O, I/16)`` block scales; each one is repeated over 16
            consecutive input columns.
        weight_scale_2: optional global scale multiplied into every block scale.
        input_scale: activation scale; accepted for call-site symmetry but not
            used to dequantize the weight.

    Returns:
        ``(O, I)`` bfloat16 tensor equal to ``decoded_codes * block_scale * global_scale``.
    """
    w8 = _as_packed_uint8(weight)
    out_features, packed_in = w8.shape
    in_features = packed_in * 2

    # 16-entry signed table: codes 0..7 map to +LUT, codes 8..15 (sign bit set) to -LUT.
    lut = FP4_LUT.to(device=w8.device, dtype=torch.float32)
    signed_lut = torch.cat([lut, -lut])
    # Mask after the shift so signed (int8) storage cannot leak sign-extended bits.
    lo = signed_lut[(w8 & 0x0F).long()]
    hi = signed_lut[((w8 >> 4) & 0x0F).long()]
    w = torch.stack([lo, hi], -1).reshape(out_features, in_features)

    s = weight_scale.float().repeat_interleave(16, 1)
    if weight_scale_2 is not None:
        s = s * weight_scale_2.float()
    return (w * s).bfloat16()


def get_nvfp4_weight(w, pfx, proj_name):
    """Return ``(weight, weight_scale, weight_scale_2, input_scale)`` for ``{pfx}.{proj_name}``.

    Entries missing from ``w`` come back as ``None``.
    """
    k = f"{pfx}.{proj_name}"
    return tuple(w.get(f"{k}.{suffix}") for suffix in _NVFP4_SUFFIXES)


def _load_tensors(cdir, keys):
    """Load only those ``keys`` that exist in the safetensors checkpoint at ``cdir``.

    Uses ``model.safetensors.index.json`` to find each key's shard when the index
    is present; otherwise every ``*.safetensors`` file in ``cdir`` is searched.
    Shards listed in the index but missing on disk are skipped.
    """
    from safetensors import safe_open

    cdir = Path(cdir)
    idx = cdir / "model.safetensors.index.json"
    if idx.exists():
        with open(idx) as f:
            wmap = json.load(f).get("weight_map", {})
        by_shard = {}
        for key in keys:
            shard = wmap.get(key)
            if shard is not None:
                by_shard.setdefault(shard, []).append(key)
    else:
        # Un-indexed (e.g. single-file) checkpoint: look in every shard present.
        by_shard = {p.name: list(keys) for p in sorted(cdir.glob("*.safetensors"))}

    out = {}
    for shard, wanted in sorted(by_shard.items()):
        path = cdir / shard
        if not path.exists():
            continue
        # safe_open reads tensors lazily, so only the requested ones are materialized.
        with safe_open(str(path), framework="pt", device="cpu") as f:
            present = set(f.keys())
            for key in wanted:
                if key in present:
                    out[key] = f.get_tensor(key)
    return out


def main():
    """Run each test case and print cosine similarity / max-abs stats (production vs reference)."""
    device = "cuda:0"
    torch.manual_seed(42)

    # Load config (also confirms CHECKPOINT_DIR points at a model directory).
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    print(f"hidden_size={cfg['hidden_size']}")

    # Test projections at different layers
    test_cases = [
        # (layer_idx, proj_name, in_features, out_features)
        (0, "model.layers.0.self_attn.q_a_proj", 7168, 1536),
        (0, "model.layers.0.self_attn.kv_proj", 7168, 512),
        (0, "model.layers.0.self_attn.q_b_proj", 1536, 65536),
        (0, "model.layers.0.self_attn.o_b_proj", 16384, 7168),
        (30, "model.layers.30.self_attn.q_a_proj", 7168, 1536),
        (60, "model.layers.60.self_attn.q_a_proj", 7168, 1536),
        (60, "model.layers.60.self_attn.kv_proj", 7168, 512),
        # Router gate
        (3, "model.layers.3.mlp.gate", 7168, 384),
        (30, "model.layers.30.mlp.gate", 7168, 384),
        (60, "model.layers.60.mlp.gate", 7168, 384),
    ]

    # Load only the tensors the test cases reference instead of every shard.
    wanted = [f"{pfx}.{suffix}" for _, pfx, _, _ in test_cases for suffix in _NVFP4_SUFFIXES]
    all_w = _load_tensors(CHECKPOINT_DIR, wanted)
    print(f"Loaded {len(all_w)} tensors")

    # Import production kernel
    from dsv4.layers.linear import Nvfp4Linear

    for li, pfx, in_f, out_f in test_cases:
        # "model.layers.3.mlp.gate" -> ("model.layers.3.mlp", "gate")
        pfx_base, _, proj_name = pfx.rpartition(".")
        weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx_base, proj_name)
        if weight is None:
            print(f"L{li} {proj_name}: weight not found, skipping")
            continue
        if ws is None or not _is_packed_fp4(weight):
            # e.g. a projection left unquantized: there is no packed FP4 data to compare.
            print(f"L{li} {proj_name}: not NVFP4-quantized (dtype={weight.dtype}, "
                  f"weight_scale={'missing' if ws is None else 'present'}), skipping")
            continue

        weight = weight.to(device)
        ws = ws.to(device)
        ws2 = ws2.to(device) if ws2 is not None else None
        isc = isc.to(device) if isc is not None else None

        actual_out = weight.shape[0]
        actual_in = weight.shape[1] * 2  # two FP4 codes per stored byte
        if (actual_in, actual_out) != (in_f, out_f):
            print(f"L{li} {proj_name}: note: checkpoint shape in={actual_in} out={actual_out} "
                  f"differs from expected in={in_f} out={out_f}")

        # Create random input
        x = torch.randn(1, actual_in, dtype=torch.bfloat16, device=device) * 5.0

        # PyTorch reference: dequant + F.linear
        w_ref = dequant_nvfp4(weight, ws, ws2, isc)
        ref_out = F.linear(x, w_ref)

        # Production: Nvfp4Linear. The attributes below are populated directly before
        # finalize_weights(); their semantics live in dsv4.layers.linear (gs=[1.0] kept
        # from the original harness — TODO confirm against the loader).
        lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=device)
        lin.fp4 = [weight.view(torch.float4_e2m1fn_x2) if weight.dtype == torch.uint8 else weight]
        lin.sf = [ws]
        lin.gs = [1.0]
        lin.ws2 = [ws2]
        # Fallback presumably 1 / (FP4 max 6.0 * FP8-E4M3 max 448.0) — confirm with kernel owner.
        isc_val = isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0)
        lin._activation_global_scale = isc_val
        lin.finalize_weights()
        prod_out = lin(x)

        # Compare
        cos = F.cosine_similarity(prod_out.flatten().float(), ref_out.flatten().float(), dim=0).item()
        max_diff = (prod_out.float() - ref_out.float()).abs().max().item()
        prod_max = prod_out.abs().max().item()
        ref_max = ref_out.abs().max().item()
        print(f"L{li} {proj_name}: cos={cos:.6f} max_diff={max_diff:.4f} |prod|={prod_max:.4f} |ref|={ref_max:.4f} ratio={prod_max/(ref_max+1e-10):.4f}")

    print("\nDone.")


if __name__ == "__main__":
    main()