#!/usr/bin/env python3
"""Layer-by-layer comparison: production kernel vs PyTorch reference.

This test loads both pipelines, runs the same input, and compares hidden
states after each layer to find where the residual diverges.

NOTE: the production pipeline (``single_shot_inference.DSV4Model``) does not
currently expose per-layer hidden states, so only the reference pipeline's
per-layer magnitudes are reported. ``decode_step`` must be instrumented
before a true per-layer diff is possible.
"""

import json
import math
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F

CHECKPOINT_DIR = os.environ.get(
    "CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)
DEVICE = "cuda:0"

PROMPT = "The capital of France is"

# Chat-template marker token ids wrapped around the prompt.
# NOTE(review): hard-coded; presumably the user/assistant markers of this
# checkpoint's tokenizer -- confirm against its tokenizer config.
USER_TOKEN = 128803
ASST_TOKEN = 128804

# Reference |X| is printed for every REF_LOG_EVERY-th layer and for every
# layer index >= REF_LOG_FROM.
REF_LOG_EVERY = 10
REF_LOG_FROM = 55


def load_config(checkpoint_dir):
    """Read and return ``config.json`` from *checkpoint_dir* as a dict."""
    with open(Path(checkpoint_dir) / "config.json", encoding="utf-8") as f:
        return json.load(f)


def build_chat_ids(ids, user_token=USER_TOKEN, asst_token=ASST_TOKEN):
    """Return a new list ``[user_token, *ids, asst_token]``.

    The input sequence is not modified.
    """
    return [user_token, *ids, asst_token]


def embed_tokens(emb_weight, token_ids, device):
    """Look up bf16 embedding rows for *token_ids* without tracking gradients.

    Args:
        emb_weight: 2-D embedding table; row ``i`` is the vector for token ``i``.
        token_ids: sequence of integer token ids.
        device: device on which the lookup is performed.

    Returns:
        A ``(len(token_ids), emb_weight.shape[1])`` bf16 tensor on *device*
        that does not require grad.
    """
    # A plain functional lookup on a detached table replaces nn.Embedding.
    # nn.Embedding would first allocate and randomly initialise a full fp32
    # table. Its Parameter (requires_grad=True) would also pull every
    # downstream reference activation into an autograd graph.
    weight = emb_weight.detach().to(device=device, dtype=torch.bfloat16)
    ids = torch.tensor(token_ids, dtype=torch.long, device=device)
    return F.embedding(ids, weight)


def run_production_prefill(prod_model, chat_ids, device):
    """Feed *chat_ids* through the production model one token at a time.

    The production KV cache is reset first. The return value of
    ``decode_step`` is currently discarded.
    """
    prod_model.kv_cache.reset()
    for pos, tid in enumerate(chat_ids):
        token_id = torch.tensor([[tid]], dtype=torch.int32, device=device)
        # TODO: capture per-layer states for the last token once the
        # production pipeline exposes them.
        prod_model.decode_step(token_id, position_offset=pos)


def build_reference_layers(all_w, n_layers, hidden, n_hc, device):
    """Build the per-layer mHC blocks and norm weights for the reference model.

    Returns:
        Four lists indexed by layer:
        - attention mHC blocks
        - FFN mHC blocks
        - input-layernorm weights
        - post-attention-layernorm weights

        The two weight lists hold bf16 tensors on *device*.
    """
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import mHCBlock

    def make_hc(prefix):
        block = mHCBlock(hidden, n_hc, device=device)
        block.load(
            all_w[f"{prefix}.fn"], all_w[f"{prefix}.base"], all_w[f"{prefix}.scale"]
        )
        return block

    def norm_weight(name):
        return all_w[name].to(device=device, dtype=torch.bfloat16)

    attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = [], [], [], []
    for li in range(n_layers):
        prefix = f"model.layers.{li}"
        attn_mhcs.append(make_hc(f"{prefix}.attn_hc"))
        ffn_mhcs.append(make_hc(f"{prefix}.ffn_hc"))
        attn_norms.append(norm_weight(f"{prefix}.input_layernorm.weight"))
        ffn_norms.append(norm_weight(f"{prefix}.post_attention_layernorm.weight"))
    return attn_mhcs, ffn_mhcs, attn_norms, ffn_norms


def run_reference_prefill(all_w, cfg, chat_ids, device):
    """Run *chat_ids* through every reference layer; return the final mHC state."""
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import forward_layer, mHCBlock

    n_layers = cfg["num_hidden_layers"]
    n_hc = cfg.get("n_hc", 4)

    # Indexing (not .get) so a missing key fails with a clear KeyError.
    emb = embed_tokens(all_w["model.embed_tokens.weight"], chat_ids, device)
    ref_X = mHCBlock.init_state(emb, n_hc=n_hc)

    attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = build_reference_layers(
        all_w, n_layers, cfg["hidden_size"], n_hc, device
    )

    print("Running reference layer by layer...")
    ref_kv_cache = {}
    for li in range(n_layers):
        ref_X = forward_layer(
            ref_X, all_w, li, cfg, None, None,
            attn_mhcs[li], ffn_mhcs[li], attn_norms[li], ffn_norms[li],
            ref_kv_cache, torch.arange(len(chat_ids), device=device), 0,
        )
        # Only reduce/sync to host when the value is actually printed.
        if li % REF_LOG_EVERY == 0 or li >= REF_LOG_FROM:
            print(f" Ref L{li}: |X|={ref_X.abs().max().item():.1f}")
    return ref_X


@torch.no_grad()  # pure inference: never build autograd graphs
def main():
    """Run both pipelines on the same prompt and report hidden-state magnitudes."""
    torch.manual_seed(42)

    cfg = load_config(CHECKPOINT_DIR)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    hd = cfg["head_dim"]
    n_hc = cfg.get("n_hc", 4)
    print(f"Model: {n_layers} layers, {H} hidden, {hd} head_dim, {n_hc} mHC streams")

    # --- Load production pipeline ---
    print("\nLoading production pipeline...")
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from single_shot_inference import DSV4Model

    prod_model = DSV4Model(CHECKPOINT_DIR, device=DEVICE)
    print("Production pipeline loaded.")

    # --- Load PyTorch reference pipeline ---
    print("\nLoading PyTorch reference pipeline...")
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import load_weights

    all_w = load_weights(CHECKPOINT_DIR)
    print("Reference pipeline loaded.")

    # --- Same input for both ---
    # NOTE: trust_remote_code executes code shipped with the checkpoint;
    # only point CHECKPOINT_DIR at trusted checkpoints.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)
    ids = tokenizer.encode(PROMPT, add_special_tokens=False)
    chat_ids = build_chat_ids(ids)
    print(f"Input: {len(chat_ids)} tokens: {chat_ids}")

    # --- Production pipeline: token-by-token prefill ---
    print("\n=== Production Pipeline: Prefill ===")
    run_production_prefill(prod_model, chat_ids, DEVICE)
    print("Production prefill done.")

    # --- Reference pipeline: full-sequence prefill, layer by layer ---
    print("\n=== Reference Pipeline: Prefill ===")
    ref_X = run_reference_prefill(all_w, cfg, chat_ids, DEVICE)
    ref_max = ref_X.abs().max().item()
    print("Reference prefill done.")
    print(f" Final |X|: {ref_max:.1f}")

    # A per-layer comparison is impossible until the production pipeline
    # exposes intermediate states; only final magnitudes are reported.
    print("\n=== Summary ===")
    print("Production final |X|: N/A (need to instrument)")
    print(f"Reference final |X|: {ref_max:.1f}")


if __name__ == "__main__":
    main()