#!/usr/bin/env python3
"""Layer-by-layer comparison between our single_shot_inference and HF reference.

Processes a single token through the LAYER 0 attention sub-block and prints
per-stage statistics (max-abs / mean-abs) so they can be compared against a
reference implementation to locate the exact point of divergence.

The intended "reference" implementation follows the HuggingFace
DeepseekV4ForCausalLM source code, using our NVFP4 dequantization for the
weights. NOTE(review): only the "OUR IMPLEMENTATION" half is present in this
script today; the reference half still has to be added for a side-by-side diff.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python tests/layer_compare.py
"""
import json
import math
import os
import sys
from collections import defaultdict
from pathlib import Path

import torch
import torch.nn.functional as F

sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from single_shot_inference import (  # noqa: E402  (needs the sys.path tweak above)
    dequant_nvfp4_weight,
    nvfp4_linear,
    RMSNorm,
    apply_rope_partial,
    apply_inverse_rope,
    build_rope_cache,
    SimpleKVCache,
    mHCBlock,
)

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# Fallback epsilon for the inline RMS normalisations when the config has no
# "rms_norm_eps" entry (this was the previously hard-coded value).
DEFAULT_RMS_EPS = 1e-6


def _load_tensors(ckpt_dir: Path, weight_map: dict, keys, device: str) -> dict:
    """Load only ``keys`` from the sharded safetensors checkpoint onto ``device``.

    ``weight_map`` maps tensor name -> shard file name (the "weight_map" of
    ``model.safetensors.index.json``). Each shard is opened with ``safe_open``
    and only the requested tensors are read, instead of materialising every
    tensor in the shard just to keep a handful.
    """
    from safetensors import safe_open

    keys_by_shard = defaultdict(list)
    for key in keys:
        keys_by_shard[weight_map[key]].append(key)

    tensors = {}
    for shard, shard_keys in keys_by_shard.items():
        with safe_open(str(ckpt_dir / shard), framework="pt") as f:
            for key in shard_keys:
                tensors[key] = f.get_tensor(key).to(device)
    return tensors


def _rms_normalize(x: torch.Tensor, weight=None, eps: float = DEFAULT_RMS_EPS) -> torch.Tensor:
    """RMS-normalise ``x`` over its last dimension in fp32 and return bf16.

    If ``weight`` is given it is applied element-wise after normalisation
    (weighted RMSNorm); with ``weight=None`` the normalisation is unweighted.
    """
    x_f = x.float()
    out = x_f * x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    if weight is not None:
        out = out * weight.float()
    return out.bfloat16()


def _nvfp4_proj(x: torch.Tensor, weights: dict, name: str) -> torch.Tensor:
    """Apply ``nvfp4_linear`` for layer ``name`` using its weight/weight_scale/weight_scale_2."""
    return nvfp4_linear(
        x,
        weights[f"{name}.weight"],
        weights[f"{name}.weight_scale"],
        weights[f"{name}.weight_scale_2"],
    )


def main():
    """Run one token through layer-0 attention and print per-stage statistics."""
    from transformers import AutoTokenizer

    cdir = Path(CHECKPOINT_DIR)
    with open(cdir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    eps = cfg.get("rms_norm_eps", DEFAULT_RMS_EPS)
    n_hc = 4  # number of mHC streams; hard-coded here, not read from config
    device = "cuda:0"

    # Load layer 0 weights (only the tensors we need, not whole shards).
    print("Loading layer 0 weights...")
    prefix = "model.layers.0."
    layer0_keys = [k for k in wm if k.startswith(prefix)]
    all_w = _load_tensors(cdir, wm, layer0_keys, device)

    # Load embedding
    embed_key = "model.embed_tokens.weight"
    embed_w = _load_tensors(cdir, wm, [embed_key], device)[embed_key].bfloat16()

    tok = AutoTokenizer.from_pretrained(str(cdir))

    # Process token "The". Special tokens are excluded so [-1] cannot pick up
    # an appended special token instead of the word itself.
    tid = torch.tensor(
        [tok.encode("The", add_special_tokens=False)[-1]], dtype=torch.long, device=device
    )
    pos = torch.tensor([0], dtype=torch.long, device=device)

    # Build RoPE cache with YaRN
    rope_params = cfg.get("rope_parameters", {})
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device,
        theta=rope_params.get("rope_theta", 10000.0),
        rope_type=rope_params.get("rope_type", "default"),
        rope_factor=rope_params.get("factor", 1.0),
        original_max_pos=rope_params.get("original_max_position_embeddings", 4096),
        beta_fast=rope_params.get("beta_fast", 32),
        beta_slow=rope_params.get("beta_slow", 1),
    )

    # Embed
    emb = F.embedding(tid, embed_w)  # (1, H)
    print(f"Embedding: |emb|={emb.abs().max():.4f}")

    # Init mHC state
    X = mHCBlock.init_state(emb, n_hc)  # (1, 4, H)

    # Load mHC parameters; fn/base rows are split as [pre | post | comb].
    fn = all_w[f"{prefix}attn_hc.fn"]
    base = all_w[f"{prefix}attn_hc.base"]
    scale = all_w[f"{prefix}attn_hc.scale"]
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    n = n_hc
    attn_mhc.load_weights(
        W_pre=fn[0:n].to(device, dtype=torch.float32),
        W_post=fn[n:2 * n].to(device, dtype=torch.float32),
        W_comb=fn[2 * n:].to(device, dtype=torch.float32),
        S_pre=base[0:n].reshape(1, n).to(device, dtype=torch.bfloat16),
        S_post=base[n:2 * n].reshape(n, 1).to(device, dtype=torch.bfloat16),
        S_comb=base[2 * n:].reshape(n, n).to(device, dtype=torch.bfloat16),
        alpha_pre=scale[0].item(),
        alpha_post=scale[1].item(),
        alpha_comb=scale[2].item(),
    )

    # === OUR IMPLEMENTATION (single_shot_inference) ===
    print("\n=== OUR IMPLEMENTATION ===")
    attn = f"{prefix}self_attn"

    # mHC pre_block
    x_in, ctx = attn_mhc.pre_block(X)
    print(f"x_in: |x_in|={x_in.abs().max():.4f} mean={x_in.float().abs().mean():.6f}")

    # RMSNorm
    norm = RMSNorm(H, device=device)
    norm.weight = all_w[f"{prefix}input_layernorm.weight"].to(device, dtype=torch.float32)
    x_norm = norm.forward(x_in)
    print(f"x_norm: |x|={x_norm.abs().max():.4f} mean={x_norm.float().abs().mean():.6f}")

    # Q projection: q_a -> q_a_norm -> q_b -> q_b_norm
    c_Q = _nvfp4_proj(x_norm, all_w, f"{attn}.q_a_proj")
    q_norm_w = all_w.get(f"{attn}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_normalize(c_Q, q_norm_w, eps)
    else:
        # Surface a likely key-name mismatch instead of silently skipping the norm.
        print(f"WARNING: {attn}.q_a_norm.weight not found; q_a_norm skipped")
    print(f"c_Q: |c_Q|={c_Q.abs().max():.4f} mean={c_Q.float().abs().mean():.6f}")

    q = _nvfp4_proj(c_Q, all_w, f"{attn}.q_b_proj")
    # q_b_norm (unweighted RMSNorm). NOTE(review): normalises over the full
    # n_h * hd vector, not per head -- confirm this matches the reference.
    q = _rms_normalize(q, eps=eps)
    q_heads = q.reshape(1, n_h, hd)
    print(f"q_heads: |q|={q_heads.abs().max():.4f} mean={q_heads.float().abs().mean():.6f}")

    # KV projection (a single shared KV head of size hd)
    kv = _nvfp4_proj(x_norm, all_w, f"{attn}.kv_proj")
    kv_norm_w = all_w.get(f"{attn}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_normalize(kv, kv_norm_w, eps)
    else:
        print(f"WARNING: {attn}.kv_norm.weight not found; kv_norm skipped")
    kv_new = kv.reshape(1, 1, hd)
    print(f"kv: |kv|={kv_new.abs().max():.4f} mean={kv_new.float().abs().mean():.6f}")

    # Apply RoPE
    q_heads = apply_rope_partial(q_heads, pos, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, pos, rope_cos, rope_sin, hd, rd)
    print(f"After RoPE: |q|={q_heads.abs().max():.4f} |kv|={kv_new.abs().max():.4f}")

    # Attention: with a single key the softmax weight is trivially 1.0, so each
    # head's output equals V.
    q_in = q_heads.permute(1, 0, 2)  # (n_h, 1, hd)
    k_in = kv_new.permute(1, 0, 2)  # (1, 1, hd)
    k_exp = k_in.expand(n_h, -1, -1)
    v_exp = k_exp.clone()  # K=V in DSV4
    attn_out = F.scaled_dot_product_attention(q_in, k_exp, v_exp, scale=1.0 / math.sqrt(hd))
    attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)
    print(f"attn_out: |o|={attn_out.abs().max():.4f} mean={attn_out.float().abs().mean():.6f}")

    # Inverse RoPE
    attn_out = apply_inverse_rope(attn_out, pos, rope_cos, rope_sin, hd, rd)
    print(f"After inverse RoPE: |o|={attn_out.abs().max():.4f}")

    # Output projection: wo_a (grouped BMM) + wo_b
    o_groups = cfg.get("num_output_groups", 16)
    o_rank = cfg.get("output_group_dim", 1024)
    if n_h % o_groups:
        raise ValueError(
            f"num_attention_heads ({n_h}) must be divisible by num_output_groups ({o_groups})"
        )
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    if f"{attn}.o_a_proj.weight_scale" in all_w:
        # The grouped BMM below treats o_a_proj as a dense weight; a scaled
        # (quantized) tensor would need dequantising first.
        raise RuntimeError(
            f"{attn}.o_a_proj has a weight_scale (quantized) but is used as a dense bf16 weight"
        )

    attn_flat = attn_out.reshape(1, n_h * hd)
    attn_grouped = attn_flat.reshape(1, o_groups, group_input_dim)
    oa_w = all_w[f"{attn}.o_a_proj.weight"].bfloat16()
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    attn_bmm = attn_grouped.permute(1, 0, 2)  # (o_groups, 1, group_input_dim)
    grouped_out = torch.bmm(attn_bmm, oa_3d.transpose(1, 2))  # (o_groups, 1, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    print(f"grouped_out: |o|={grouped_flat.abs().max():.4f}")

    F_attn = _nvfp4_proj(grouped_flat, all_w, f"{attn}.o_b_proj")
    print(f"F_attn: |F|={F_attn.abs().max():.4f} mean={F_attn.float().abs().mean():.6f}")

    # mHC post_block
    X_mid = attn_mhc.post_block(X, F_attn, ctx)
    print(f"X_mid: |X|={X_mid.abs().max():.4f}")

    print("\nLayer 0 attention sub-block complete.")


if __name__ == "__main__":
    main()