Files
nvfp4-megamoe-kernel/tests/unit/test_layer_comparison.py
biondizzle 8de47e26ce Cleanup Step 1: Move root-level files to proper directories
- Move test_*.py → tests/integration/
- Move probe_*.py, dump_*.py → helpers/
- Move PERFORMANCE_AUDIT.md → docs/
- Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/
- Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias
- Add helpers/import_closure.py (dead-code detection tool)
2026-06-02 19:24:39 +00:00

125 lines
4.9 KiB
Python

#!/usr/bin/env python3
"""Layer-by-layer comparison scaffold: production kernel vs PyTorch reference.

Loads both pipelines, prefills the same chat prompt through each, and prints
the reference model's hidden-state magnitude (max |X|) at selected layers.

The production pipeline does not yet expose per-layer hidden states, so this
script does NOT compare the two pipelines numerically and asserts nothing;
the summary reports the production value as "N/A". The production model must
be instrumented before this can pinpoint where the residual diverges.

Environment:
    CHECKPOINT_DIR  checkpoint directory (config.json, tokenizer, weights)
    DEVICE          torch device used by both pipelines (default ``cuda:0``)
"""
import os, sys, json, time, math, torch, torch.nn.functional as F
from pathlib import Path

# Both settings are overridable via environment variables of the same name so
# the script can target another checkpoint or GPU without editing the file.
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
DEVICE = os.environ.get("DEVICE", "cuda:0")
# Chat-template token IDs wrapped around the tokenized prompt. Presumably the
# user / assistant role markers of the DeepSeek tokenizer -- TODO confirm
# against the tokenizer's special-token table.
_USER_TOKEN_ID = 128803
_ASSISTANT_TOKEN_ID = 128804


def _ensure_import_paths():
    """Make ``single_shot_inference`` and the ``dsv4`` package importable.

    When this file is run as a script, Python only puts the script's own
    directory on ``sys.path``, so the absolute ``dsv4.reference...`` import
    also needs the repository root. The test directory is inserted last so
    it keeps front-of-path priority.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    # Assumes the layout <repo>/tests/unit/<this file> -- TODO confirm that
    # ``dsv4`` (and, if not beside this file, ``single_shot_inference``)
    # live directly under <repo>.
    repo_root = str(Path(test_dir).parents[1])
    sys.path.insert(0, repo_root)
    sys.path.insert(0, test_dir)


def _encode_chat_prompt(prompt):
    """Tokenize ``prompt`` and wrap it in the user/assistant marker tokens."""
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)
    ids = tokenizer.encode(prompt, add_special_tokens=False)
    return [_USER_TOKEN_ID, *ids, _ASSISTANT_TOKEN_ID]


def _run_production_prefill(prod_model, chat_ids):
    """Feed ``chat_ids`` through the production model one token at a time."""
    prod_model.kv_cache.reset()
    for pos, tid in enumerate(chat_ids):
        token_id = torch.tensor([[tid]], dtype=torch.int32, device=DEVICE)
        prod_model.decode_step(token_id, position_offset=pos)
    # TODO: capture per-layer hidden states for the last token once
    # DSV4Model exposes them; only then can the pipelines be compared
    # layer by layer.


def _load_mhc(mhc_cls, all_w, prefix, hidden_size, n_hc):
    """Build one reference mHC block and load ``{prefix}.fn/.base/.scale``."""
    block = mhc_cls(hidden_size, n_hc, device=DEVICE)
    block.load(all_w[f"{prefix}.fn"],
               all_w[f"{prefix}.base"],
               all_w[f"{prefix}.scale"])
    return block


def _norm_weight(all_w, key):
    """Return the norm weight ``all_w[key]`` as bf16 on ``DEVICE``."""
    return all_w[key].bfloat16().to(DEVICE)


@torch.no_grad()
def main():
    """Run one chat prompt through both pipelines and report magnitudes.

    Steps:
      1. Read ``config.json`` from ``CHECKPOINT_DIR``.
      2. Load the production ``DSV4Model`` and the PyTorch reference weights.
      3. Prefill the production model token by token (decode style).
      4. Run the reference model layer by layer over the whole prompt,
         printing ``max |X|`` at selected layers.
      5. Print a summary. The production model does not expose per-layer
         states yet, so no numerical comparison is made.

    Runs under ``torch.no_grad()``: this is inference only, and it keeps the
    reference forward from recording an autograd graph through every layer.
    """
    torch.manual_seed(42)

    # --- Config ---
    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    hd = cfg["head_dim"]
    n_hc = cfg.get("n_hc", 4)
    print(f"Model: {n_layers} layers, {H} hidden, {hd} head_dim, {n_hc} mHC streams")

    _ensure_import_paths()

    # --- Load production pipeline ---
    print("\nLoading production pipeline...")
    from single_shot_inference import DSV4Model

    prod_model = DSV4Model(CHECKPOINT_DIR, device=DEVICE)
    print("Production pipeline loaded.")

    # --- Load PyTorch reference pipeline ---
    print("\nLoading PyTorch reference pipeline...")
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import (
        forward_layer,
        load_weights,
        mHCBlock,
    )

    all_w = load_weights(CHECKPOINT_DIR)
    print("Reference pipeline loaded.")

    # --- Same input for both ---
    chat_ids = _encode_chat_prompt("The capital of France is")
    print(f"Input: {len(chat_ids)} tokens: {chat_ids}")

    # --- Run production pipeline: prefill ---
    print("\n=== Production Pipeline: Prefill ===")
    _run_production_prefill(prod_model, chat_ids)
    print("Production prefill done.")

    # --- Run reference pipeline: prefill ---
    print("\n=== Reference Pipeline: Prefill ===")
    # F.embedding on the loaded weight is what nn.Embedding.forward computes,
    # without allocating and randomly initialising a throwaway vocab x H
    # parameter first. Indexing (not .get) so a missing key raises KeyError.
    emb_w = all_w["model.embed_tokens.weight"].bfloat16().to(DEVICE)
    token_ids = torch.tensor(chat_ids, device=DEVICE)
    ref_X = mHCBlock.init_state(F.embedding(token_ids, emb_w), n_hc=n_hc)

    # Build mHC blocks and norms for every reference layer.
    attn_mhcs, ffn_mhcs = [], []
    attn_norms, ffn_norms = [], []
    for li in range(n_layers):
        prefix = f"model.layers.{li}"
        attn_mhcs.append(_load_mhc(mHCBlock, all_w, f"{prefix}.attn_hc", H, n_hc))
        ffn_mhcs.append(_load_mhc(mHCBlock, all_w, f"{prefix}.ffn_hc", H, n_hc))
        attn_norms.append(_norm_weight(all_w, f"{prefix}.input_layernorm.weight"))
        ffn_norms.append(_norm_weight(all_w, f"{prefix}.post_attention_layernorm.weight"))

    # Run reference layer by layer.
    print("Running reference layer by layer...")
    ref_kv_cache = {}
    for li in range(n_layers):
        ref_X = forward_layer(ref_X, all_w, li, cfg, None, None,
                              attn_mhcs[li], ffn_mhcs[li],
                              attn_norms[li], ffn_norms[li],
                              ref_kv_cache, torch.arange(len(chat_ids), device=DEVICE),
                              0)
        # Every 10th layer plus layers >= 55 (presumably the final layers of
        # this checkpoint -- TODO confirm / derive from n_layers). The .item()
        # device sync only happens for layers that are actually printed.
        if li % 10 == 0 or li >= 55:
            print(f" Ref L{li}: |X|={ref_X.abs().max().item():.1f}")
    ref_final = ref_X.abs().max().item()
    print("Reference prefill done.")
    print(f" Final |X|: {ref_final:.1f}")

    # Compare
    # Per-layer comparison is not possible yet: the production pipeline does
    # not expose intermediate states, so only the reference value is known.
    print("\n=== Summary ===")
    print("Production final |X|: N/A (need to instrument)")
    print(f"Reference final |X|: {ref_final:.1f}")
# Script entry point. NOTE(review): this file defines no ``test_*`` function,
# so pytest collects nothing from it; it is meant to be run directly, with the
# checkpoint at CHECKPOINT_DIR and the device named by DEVICE available.
if __name__ == "__main__":
    main()