#!/usr/bin/env python3
"""Trace layer 0's attention sub-block through our single_shot_inference code.

A single token ("The") is embedded and pushed through the attention half of
LAYER 0 using the building blocks imported from ``single_shot_inference``.
After every stage the max-abs (and usually the mean-abs) of the intermediate
tensor is printed, so the numbers can be diffed against a HuggingFace
DeepseekV4ForCausalLM run to locate the first point of divergence.

NOTE: only *our* implementation is executed by this script. No HF-style
"reference" path is implemented here, so the comparison itself has to be made
against reference values produced elsewhere.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python tests/layer_compare.py
"""
import json
import math
import os  # noqa: F401  (not used by the current trace; existing import retained)
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

# The kernel directory must be on sys.path before the project import below.
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from single_shot_inference import (  # noqa: E402
    dequant_nvfp4_weight,  # noqa: F401  (not used by the current trace)
    nvfp4_linear,
    RMSNorm,
    apply_rope_partial,
    apply_inverse_rope,
    build_rope_cache,
    SimpleKVCache,  # noqa: F401  (not used by the current trace)
    mHCBlock,
)

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"


def _rms_normalize(x, weight=None, eps=1e-6):
    """Return *x* RMS-normalised over its last dimension, cast to bfloat16.

    The arithmetic is done in float32: ``x * rsqrt(mean(x^2) + eps)``, then an
    optional element-wise multiply by ``weight`` (also promoted to float32).

    Args:
        x: Tensor to normalise.
        weight: Optional per-feature gain; ``None`` gives an unweighted norm.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        The normalised tensor as bfloat16.
    """
    x_f = x.float()
    inv_rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    out = x_f * inv_rms
    if weight is not None:
        out = out * weight.float()
    return out.bfloat16()


def _load_tensors(cdir, weight_map, keys, device):
    """Load the tensors named in *keys* from their safetensors shards.

    Args:
        cdir: Checkpoint directory (``Path``) containing the shard files.
        weight_map: Mapping of tensor name -> shard file name, as read from
            ``model.safetensors.index.json``.
        keys: Tensor names to load; every name must be present in *weight_map*.
        device: Device the loaded tensors are moved to.

    Returns:
        Dict of tensor name -> tensor on *device*. A name that the index maps
        to a shard but that is absent from that shard is skipped silently.
    """
    from safetensors.torch import load_file

    tensors = {}
    # Sorted only to make the shard visiting order deterministic.
    for shard in sorted({weight_map[k] for k in keys}):
        data = load_file(str(cdir / shard))
        for k in keys:
            if k in data:
                tensors[k] = data[k].to(device)
    return tensors


@torch.no_grad()  # forward-only trace: no autograd graph is needed
def main():
    """Run one token through layer 0's attention sub-block and print stats.

    Reads ``config.json`` and the safetensors index from ``CHECKPOINT_DIR``,
    loads the layer-0 tensors and the embedding table onto ``cuda:0``, then
    executes mHC pre-block -> RMSNorm -> Q/KV projections -> RoPE -> attention
    -> inverse RoPE -> grouped output projection -> mHC post-block, printing
    the magnitude of each intermediate.

    Raises:
        KeyError: If a required config entry or weight name is missing.
        ValueError: If ``num_attention_heads`` is not divisible by the number
            of output groups.
    """
    from transformers import AutoTokenizer

    cdir = Path(CHECKPOINT_DIR)
    with open(cdir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    n_hc = 4
    device = "cuda:0"

    # Load layer 0 weights
    print("Loading layer 0 weights...")
    prefix = "model.layers.0."
    layer0_keys = [k for k in wm if k.startswith(prefix)]
    all_w = _load_tensors(cdir, wm, layer0_keys, device)

    # Load embedding
    embed_key = "model.embed_tokens.weight"
    embed_w = _load_tensors(cdir, wm, [embed_key], device)[embed_key].bfloat16()
    tok = AutoTokenizer.from_pretrained(str(cdir))

    # Process token "The"; [-1] takes the last id produced by the tokenizer.
    tid = torch.tensor([tok.encode("The")[-1]], dtype=torch.long, device=device)
    pos = torch.tensor([0], dtype=torch.long, device=device)

    # Build RoPE cache with YaRN.
    # `or {}` also covers a config that sets "rope_parameters" to null.
    rope_params = cfg.get("rope_parameters") or {}
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device, theta=rope_params.get("rope_theta", 10000.0),
        rope_type=rope_params.get("rope_type", "default"),
        rope_factor=rope_params.get("factor", 1.0),
        original_max_pos=rope_params.get("original_max_position_embeddings", 4096),
        beta_fast=rope_params.get("beta_fast", 32),
        beta_slow=rope_params.get("beta_slow", 1),
    )

    # Embed
    emb = F.embedding(tid, embed_w)  # (1, H)
    print(f"Embedding: |emb|={emb.abs().max():.4f}")

    # Init mHC state
    X = mHCBlock.init_state(emb, n_hc)  # (1, 4, H) per the original author's note

    # Load mHC: fn/base are sliced into pre (n), res (n*n) and post (rest) parts.
    fn = all_w[f"{prefix}attn_hc.fn"]
    base = all_w[f"{prefix}attn_hc.base"]
    scale = all_w[f"{prefix}attn_hc.scale"]
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    n = n_hc
    attn_mhc.load_weights(
        W_pre=fn[0:n].to(device, dtype=torch.float32),
        W_res=fn[n:n+n*n].to(device, dtype=torch.float32),
        W_post=fn[n+n*n:].to(device, dtype=torch.float32),
        S_pre=base[0:n].reshape(1, n).to(device, dtype=torch.bfloat16),
        S_res=base[n:n+n*n].reshape(n, n).to(device, dtype=torch.bfloat16),
        S_post=base[n+n*n:].reshape(n, 1).to(device, dtype=torch.bfloat16),
        alpha_pre=scale[0].item(),
        alpha_res=scale[1].item(),
        alpha_post=scale[2].item(),
    )

    def project(x, name):
        """Apply the NVFP4 linear layer ``self_attn.<name>`` of layer 0 to x."""
        stem = f"{prefix}self_attn.{name}"
        return nvfp4_linear(x, all_w[f"{stem}.weight"],
                            all_w[f"{stem}.weight_scale"],
                            all_w[f"{stem}.weight_scale_2"])

    # === OUR IMPLEMENTATION (single_shot_inference) ===
    print("\n=== OUR IMPLEMENTATION ===")

    # mHC pre_block
    x_in, ctx = attn_mhc.pre_block(X)
    print(f"x_in: |x_in|={x_in.abs().max():.4f} mean={x_in.float().abs().mean():.6f}")

    # RMSNorm
    norm = RMSNorm(H, device=device)
    norm.weight = all_w[f"{prefix}input_layernorm.weight"].to(device, dtype=torch.float32)
    x_norm = norm.forward(x_in)
    print(f"x_norm: |x|={x_norm.abs().max():.4f} mean={x_norm.float().abs().mean():.6f}")

    # Q projection: q_a → q_a_norm → q_b → q_b_norm
    c_Q = project(x_norm, "q_a_proj")
    # q_a_norm is applied only when the checkpoint provides its weight.
    q_norm_w = all_w.get(f"{prefix}self_attn.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_normalize(c_Q, q_norm_w)
    print(f"c_Q: |c_Q|={c_Q.abs().max():.4f} mean={c_Q.float().abs().mean():.6f}")

    # q_b_norm (unweighted RMSNorm) is always applied.
    q = _rms_normalize(project(c_Q, "q_b_proj"))
    q_heads = q.reshape(1, n_h, hd)
    print(f"q_heads: |q|={q_heads.abs().max():.4f} mean={q_heads.float().abs().mean():.6f}")

    # KV projection
    kv = project(x_norm, "kv_proj")
    # kv_norm is applied only when the checkpoint provides its weight.
    kv_norm_w = all_w.get(f"{prefix}self_attn.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_normalize(kv, kv_norm_w)
    kv_new = kv.reshape(1, 1, hd)
    print(f"kv: |kv|={kv_new.abs().max():.4f} mean={kv_new.float().abs().mean():.6f}")

    # Apply RoPE
    q_heads = apply_rope_partial(q_heads, pos, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, pos, rope_cos, rope_sin, hd, rd)
    print(f"After RoPE: |q|={q_heads.abs().max():.4f} |kv|={kv_new.abs().max():.4f}")

    # Attention (single token, so the softmax weight is trivially 1.0)
    q_in = q_heads.permute(1, 0, 2)  # (n_h, 1, hd)
    k_in = kv_new.permute(1, 0, 2)  # (1, 1, hd)
    k_exp = k_in.expand(n_h, -1, -1)
    v_exp = k_exp.clone()  # K=V in DSV4
    attn_out = F.scaled_dot_product_attention(q_in, k_exp, v_exp, scale=1.0/math.sqrt(hd))
    attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)
    print(f"attn_out: |o|={attn_out.abs().max():.4f} mean={attn_out.float().abs().mean():.6f}")

    # Inverse RoPE
    attn_out = apply_inverse_rope(attn_out, pos, rope_cos, rope_sin, hd, rd)
    print(f"After inverse RoPE: |o|={attn_out.abs().max():.4f}")

    # Output projection: wo_a (grouped BMM) + wo_b
    o_groups = cfg.get("num_output_groups", 16)
    o_rank = cfg.get("output_group_dim", 1024)
    if n_h % o_groups:
        raise ValueError(
            f"num_attention_heads ({n_h}) is not divisible by the number of "
            f"output groups ({o_groups})"
        )
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    attn_flat = attn_out.reshape(1, n_h * hd)
    attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd)
    # NOTE(review): o_a_proj is cast straight to bfloat16 with no weight_scale
    # applied, i.e. this assumes it is stored unquantized — confirm.
    oa_w = all_w[f"{prefix}self_attn.o_a_proj.weight"].bfloat16()
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    attn_bmm = attn_grouped.permute(1, 0, 2)
    grouped_out = torch.bmm(attn_bmm, oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    print(f"grouped_out: |o|={grouped_flat.abs().max():.4f}")

    F_attn = project(grouped_flat, "o_b_proj")
    print(f"F_attn: |F|={F_attn.abs().max():.4f} mean={F_attn.float().abs().mean():.6f}")

    # mHC post_block
    X_mid = attn_mhc.post_block(X, F_attn, ctx)
    print(f"X_mid: |X|={X_mid.abs().max():.4f}")

    print("\nLayer 0 attention sub-block complete.")


if __name__ == "__main__":
    main()