- Move test_*.py → tests/integration/
- Move probe_*.py, dump_*.py → helpers/
- Move PERFORMANCE_AUDIT.md → docs/
- Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/
- Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias
- Add helpers/import_closure.py (dead-code detection tool)
125 lines
4.9 KiB
Python
125 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
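"""Layer-by-layer comparison: production kernel vs PyTorch reference.

This script loads both pipelines and runs the same input through each.
The goal is to compare hidden states after every layer to find where the
residual stream diverges. The production pipeline does not yet expose its
intermediate (or final) hidden states, so for now only the reference
pipeline's per-layer magnitudes are reported.
"""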
|
|
import os, sys, json, time, math, torch, torch.nn.functional as F
|
|
from pathlib import Path
|
|
|
|
# Checkpoint directory (config.json, tokenizer, weights); override with the
# CHECKPOINT_DIR environment variable.
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
# Torch device used by both pipelines; override with the DEVICE environment
# variable (defaults to the first CUDA device, as before).
DEVICE = os.environ.get("DEVICE", "cuda:0")
|
|
|
|
@torch.no_grad()  # pure inference: never build an autograd graph across layers
def main():
    """Run the production and PyTorch-reference pipelines on the same prompt.

    Both pipelines are loaded from CHECKPOINT_DIR and fed the same
    hand-built chat-formatted prompt. The production model is driven one
    token at a time through ``decode_step``. The reference model runs the
    whole prompt through ``forward_layer`` one layer at a time, and the
    max-abs magnitude of its hidden state is printed for selected layers.

    The production pipeline does not expose per-layer or final hidden
    states here, so no numerical comparison between the two is made yet.
    """
    torch.manual_seed(42)

    # --- Load config ---
    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    hd = cfg["head_dim"]
    n_hc = cfg.get("n_hc", 4)
    print(f"Model: {n_layers} layers, {H} hidden, {hd} head_dim, {n_hc} mHC streams")

    # --- Import paths ---
    # The script's own directory goes first on sys.path (so
    # `single_shot_inference` next to this file wins). The directory two
    # levels up is also added so `dsv4.reference...` resolves when this file
    # sits in tests/integration/ -- assumed repo layout, verify.
    script_dir = Path(os.path.abspath(__file__)).parent
    repo_root = script_dir.parent.parent
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    sys.path.insert(0, str(script_dir))

    # --- Load production pipeline ---
    print("\nLoading production pipeline...")
    from single_shot_inference import DSV4Model
    prod_model = DSV4Model(CHECKPOINT_DIR, device=DEVICE)
    print("Production pipeline loaded.")

    # --- Load PyTorch reference pipeline ---
    print("\nLoading PyTorch reference pipeline...")
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import (
        forward_layer,
        load_weights,
        mHCBlock,
    )
    all_w = load_weights(CHECKPOINT_DIR)
    print("Reference pipeline loaded.")

    # --- Same input for both ---
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)
    prompt = "The capital of France is"
    ids = tokenizer.encode(prompt, add_special_tokens=False)
    # Hand-built chat template. These special-token ids are hard-coded;
    # presumably the user / assistant role markers -- verify against the
    # tokenizer config.
    user_token = 128803
    asst_token = 128804
    chat_ids = [user_token, *ids, asst_token]
    print(f"Input: {len(chat_ids)} tokens: {chat_ids}")

    # --- Run production pipeline: prefill ---
    # Tokens are fed one at a time (decode style), each at its own position.
    # Per-layer states are not captured: DSV4Model is not instrumented here.
    print("\n=== Production Pipeline: Prefill ===")
    prod_model.kv_cache.reset()
    for pos, tid in enumerate(chat_ids):
        token_id = torch.tensor([[tid]], dtype=torch.int32, device=DEVICE)
        prod_model.decode_step(token_id, position_offset=pos)
    print("Production prefill done.")

    # --- Run reference pipeline: prefill ---
    print("\n=== Reference Pipeline: Prefill ===")
    # Embed directly with the checkpoint's embedding matrix (bf16 on DEVICE).
    # This avoids allocating and randomly initialising a throw-away fp32
    # nn.Embedding on the CPU. A missing key raises KeyError naming it.
    emb_weight = all_w["model.embed_tokens.weight"].bfloat16().to(DEVICE)
    input_ids = torch.tensor(chat_ids, device=DEVICE)
    ref_X = mHCBlock.init_state(F.embedding(input_ids, emb_weight), n_hc=n_hc)

    def load_mhc(prefix):
        # Build one mHCBlock from the checkpoint's <prefix>.{fn,base,scale} tensors.
        block = mHCBlock(H, n_hc, device=DEVICE)
        block.load(all_w[f"{prefix}.fn"],
                   all_w[f"{prefix}.base"],
                   all_w[f"{prefix}.scale"])
        return block

    # Build mHC blocks and norm weights for every reference layer.
    attn_mhcs, ffn_mhcs = [], []
    attn_norms, ffn_norms = [], []
    for li in range(n_layers):
        layer = f"model.layers.{li}"
        attn_mhcs.append(load_mhc(f"{layer}.attn_hc"))
        ffn_mhcs.append(load_mhc(f"{layer}.ffn_hc"))
        attn_norms.append(all_w[f"{layer}.input_layernorm.weight"].bfloat16().to(DEVICE))
        ffn_norms.append(all_w[f"{layer}.post_attention_layernorm.weight"].bfloat16().to(DEVICE))

    # Run the reference one layer at a time over the whole prompt.
    print("Running reference layer by layer...")
    ref_kv_cache = {}
    # Loop-invariant; assumes forward_layer does not modify it in place.
    positions = torch.arange(len(chat_ids), device=DEVICE)
    # Report every 10th layer plus the last 6 layers.
    tail_start = n_layers - 6
    for li in range(n_layers):
        ref_X = forward_layer(ref_X, all_w, li, cfg, None, None,
                              attn_mhcs[li], ffn_mhcs[li],
                              attn_norms[li], ffn_norms[li],
                              ref_kv_cache, positions,
                              0)
        # .item() forces a device sync, so only do it when printing.
        if li % 10 == 0 or li >= tail_start:
            print(f"  Ref L{li}: |X|={ref_X.abs().max().item():.1f}")

    ref_final = ref_X.abs().max().item()
    print("Reference prefill done.")
    print(f"  Final |X|: {ref_final:.1f}")

    # --- Compare ---
    # A per-layer comparison is not possible yet because the production
    # pipeline does not expose intermediate states; only the reference
    # magnitude is reported.
    print("\n=== Summary ===")
    print("Production final |X|: N/A (need to instrument)")
    print(f"Reference final |X|: {ref_final:.1f}")


if __name__ == "__main__":
    main()
|