#!/usr/bin/env python3
"""Per-layer validation: compare forward_layer output with a step-by-step reference.

This test builds a random residual-stream input for a single token, runs it
through a SINGLE layer with the production ``forward_layer`` function, and
independently recomputes the attention half of that layer (mHC pre-block,
attention, output projection, mHC post-block) with a standalone PyTorch
reference. Intermediate reference values are printed; the reference result
(X_mid) is then compared with the production output.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python3 tests/validate_layer.py --layer 0
"""
import argparse
import json
import math
import os
import sys
from pathlib import Path

import torch

# Put the parent of this file's directory on sys.path (presumably kernel/, per
# the usage notes) so that `single_shot_inference` is importable.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"


# =====================================================================
# NVFP4 dequantization
# =====================================================================

# Magnitudes of the eight non-negative FP4 (E2M1) codes; bit 3 of each
# nibble carries the sign.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2, block_size=16):
    """Dequantize a packed NVFP4 weight matrix to bfloat16.

    Args:
        weight: 2-D tensor of packed bytes, shape (out, in // 2). The low
            nibble of each byte is the even input column, the high nibble the
            odd one. Works for both uint8 and int8 storage.
        weight_scale: per-block scales, shape (out, in // block_size).
        weight_scale_2: global scale, broadcast against ``weight_scale``.
        block_size: number of consecutive input columns sharing one scale.

    Returns:
        bfloat16 tensor of shape (out, in).
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2

    # Mask after the shift so a signed storage dtype cannot sign-extend the
    # high nibble.
    low = (weight & 0x0F).to(torch.int8)
    high = ((weight >> 4) & 0x0F).to(torch.int8)

    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(nibble):
        magnitude = lut[(nibble & 0x07).long()]
        return torch.where((nibble >> 3).bool(), -magnitude, magnitude)

    w_f = torch.stack([decode(low), decode(high)], dim=-1).reshape(out_dim, in_features)
    scale_f = weight_scale.float() * weight_scale_2.float()
    scale_expanded = scale_f.repeat_interleave(block_size, dim=1)
    return (w_f * scale_expanded).bfloat16()


def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Compute ``x @ W.T`` where W is dequantized from NVFP4 on every call (no bias)."""
    w = dequant_nvfp4(weight, weight_scale, weight_scale_2)
    return torch.nn.functional.linear(x, w)


def rmsnorm(x, weight, eps=1e-6):
    """Weighted RMSNorm matching HF DeepseekV4RMSNorm.

    Normalizes over the last dimension in float32 and returns bfloat16.
    """
    x_f = x.float()
    rms_inv = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * rms_inv * weight.float()).to(torch.bfloat16)


def unweighted_rmsnorm(x, eps=1e-6):
    """Unweighted RMSNorm matching HF DeepseekV4UnweightedRMSNorm.

    Normalizes over the last dimension in float32 and returns bfloat16.
    """
    x_f = x.float()
    rms_inv = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * rms_inv).to(torch.bfloat16)


def sinkhorn(logits, t_max=20, eps=1e-6):
    """Sinkhorn-Knopp normalization starting from a softmax, matching HF.

    Softmax over the last dim, then alternately normalize columns (sum over
    dim -2) and rows (sum over dim -1); the final step is a column
    normalization. The result is approximately doubly stochastic.
    """
    M = torch.softmax(logits, dim=-1) + eps
    M = M / (M.sum(dim=-2, keepdim=True) + eps)
    for _ in range(t_max - 1):
        M = M / (M.sum(dim=-1, keepdim=True) + eps)
        M = M / (M.sum(dim=-2, keepdim=True) + eps)
    return M


def build_rope_cache(max_pos, rope_dim, device, theta=10000.0, rope_type="default",
                     rope_factor=1.0, original_max_pos=4096, beta_fast=32, beta_slow=1):
    """Build RoPE cos/sin tables.

    Returns:
        (cos, sin), each float32 of shape (max_pos, rope_dim // 2), on ``device``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    if rope_type == "yarn" and rope_factor > 1.0:
        # NOTE(review): this frequency blend appears to differ from the
        # reference YaRN formulation (correction-range linear ramp, plus the
        # attention mscale factor). It is left numerically unchanged because
        # production and reference in this script share the same cache, and
        # position 0 (the only one tested) is unaffected. Confirm against the
        # model's RoPE before reusing it elsewhere.
        new_freqs = []
        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < original_max_pos / (beta_fast * 2.0):
                new_freqs.append(freq)
            elif wavelen > original_max_pos / (beta_slow * 2.0):
                new_freqs.append(freq / rope_factor)
            else:
                smooth = (original_max_pos / (wavelen * beta_slow) - rope_factor) / (
                    rope_factor * (beta_fast / beta_slow - 1))
                new_freqs.append((1 - smooth) * freq / rope_factor + smooth * freq)
        freqs = torch.tensor(new_freqs, dtype=torch.float32)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)


def _rotate_rope_tail(x, positions, cos_cache, sin_cache, rope_dim, inverse):
    """Rotate the trailing ``rope_dim`` features of each head in (even, odd) pairs.

    ``inverse=True`` rotates by the negated angle. The leading (non-rotary)
    features are copied unchanged. Returns a new bfloat16-written tensor of
    x's shape; ``x`` must be 3-D (tokens, heads, head_dim).
    """
    _, _, hd = x.shape
    nope = hd - rope_dim
    cos = cos_cache[positions].unsqueeze(1)
    sin = sin_cache[positions].unsqueeze(1)
    if inverse:
        sin = -sin  # rotation by -angle: cos unchanged, sin negated
    x_rope = x[:, :, nope:].float()
    x_even = x_rope[..., 0::2]
    x_odd = x_rope[..., 1::2]
    rope_out = torch.empty_like(x_rope)
    rope_out[..., 0::2] = x_even * cos - x_odd * sin
    rope_out[..., 1::2] = x_even * sin + x_odd * cos
    result = x.clone()
    result[:, :, nope:] = rope_out.to(torch.bfloat16)
    return result


def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply RoPE to the last ``rope_dim`` features of each head of ``x`` (T, heads, hd).

    ``head_dim`` is unused; the head size is taken from ``x.shape``.
    """
    return _rotate_rope_tail(x, positions, cos_cache, sin_cache, rope_dim, inverse=False)


def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Undo ``apply_rope_partial`` on ``o`` (T, heads, hd) by rotating with the negated angle.

    ``head_dim`` is unused; the head size is taken from ``o.shape``.
    """
    return _rotate_rope_tail(o, positions, cos_cache, sin_cache, rope_dim, inverse=True)


# =====================================================================
# Layer validation
# =====================================================================

def _nvfp4_proj(x, w, prefix):
    """Apply the NVFP4 linear layer stored under ``{prefix}.weight*`` in ``w`` to ``x``."""
    return nvfp4_linear(x, w[f"{prefix}.weight"], w[f"{prefix}.weight_scale"],
                        w[f"{prefix}.weight_scale_2"])


def _dense_weight(w, prefix):
    """Return ``{prefix}.weight`` as bfloat16, dequantizing first if NVFP4 scales are present."""
    weight = w[f"{prefix}.weight"]
    scale_key = f"{prefix}.weight_scale"
    if scale_key in w:
        return dequant_nvfp4(weight, w[scale_key], w[f"{prefix}.weight_scale_2"])
    return weight.bfloat16()


def _stream_norms(X):
    """L2 norm of each residual stream of a (1, n_streams, H) tensor, as Python floats."""
    return [X[0, s, :].float().norm().item() for s in range(X.shape[1])]


def validate_layer(li, all_weights, cfg, device='cuda:0', cos_threshold=0.99):
    """Validate one layer: production ``forward_layer`` vs. a step-by-step reference.

    Args:
        li: layer index.
        all_weights: dict of checkpoint tensors for this layer (already on ``device``).
        cfg: model config dict (from config.json).
        device: torch device string.
        cos_threshold: minimum per-stream cosine similarity counted as a match.

    Returns:
        The reference residual after the attention sub-block, X_mid, of shape
        (1, n_hc, H), bfloat16.
    """
    from single_shot_inference import forward_layer, mHCBlock, RMSNorm, SimpleKVCache

    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    H = cfg["hidden_size"]
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_group_dim", 1024)
    eps = cfg.get('rms_norm_eps', 1e-6)
    n_hc = 4  # number of mHC residual streams
    layer = f"model.layers.{li}"
    pre = f"{layer}.self_attn"

    w = all_weights  # Already filtered and on device

    # --- RoPE caches (shared by production and reference) ---
    rope_params = cfg.get("rope_parameters") or {}
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device,
        theta=rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0)),
        rope_type=rope_params.get("rope_type", "yarn"),
        rope_factor=rope_params.get("factor", 16.0),
        original_max_pos=rope_params.get("original_max_position_embeddings", 4096),
        beta_fast=rope_params.get("beta_fast", 32),
        beta_slow=rope_params.get("beta_slow", 1),
    )

    # --- Production modules ---
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    for block, name in ((attn_mhc, "attn_hc"), (ffn_mhc, "ffn_hc")):
        keys = [f"{layer}.{name}.{part}" for part in ("fn", "base", "scale")]
        if all(k in w for k in keys):
            block.load_from_checkpoint(*(w[k] for k in keys))

    attn_norm = RMSNorm(H, eps=eps, device=device)
    an_key = f"{layer}.input_layernorm.weight"
    if an_key in w:
        attn_norm.weight = w[an_key].to(device=device, dtype=torch.float32)
    ffn_norm = RMSNorm(H, eps=eps, device=device)
    fn_key = f"{layer}.post_attention_layernorm.weight"
    if fn_key in w:
        ffn_norm.weight = w[fn_key].to(device=device, dtype=torch.float32)

    kv_cache = SimpleKVCache(head_dim=hd, max_seq=8192, device=device)

    # --- Input: random residual streams for one token at position 0 ---
    torch.manual_seed(42)
    X_l = torch.randn(1, n_hc, H, dtype=torch.bfloat16, device=device) * 0.5
    positions = torch.tensor([0], dtype=torch.long, device=device)
    token_id = torch.tensor([671], dtype=torch.long, device=device)  # "The" (tokenizer-dependent)

    # --- Production ---
    X_prod = forward_layer(X_l.clone(), w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
                           attn_norm, ffn_norm, kv_cache, token_id, positions)
    print(f"Production: |X_next|={X_prod.abs().max().item():.4f}")
    print(f"  Stream norms: {_stream_norms(X_prod)}")

    # ============================================================
    # Step-by-step reference (same math, no wrappers)
    # ============================================================
    X_l_ref = X_l.clone()

    # --- mHC pre_block (attention) ---
    fn = w[f"{layer}.attn_hc.fn"].float().to(device)
    base = w[f"{layer}.attn_hc.base"].float().to(device)
    scale = w[f"{layer}.attn_hc.scale"].float().to(device)

    # Unweighted RMSNorm over all streams flattened together.
    # NOTE(review): eps is the helper default 1e-6 here, not cfg rms_norm_eps;
    # confirm this matches mHCBlock.
    flat = unweighted_rmsnorm(X_l_ref.reshape(1, n_hc * H))

    # F.linear, then split into [pre(n_hc), post(n_hc), comb(n_hc*n_hc)].
    proj = torch.nn.functional.linear(flat.float(), fn).float()
    pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1)
    pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc])
    pre_s, post_s, comb_s = scale.unbind(0)
    A_l = torch.sigmoid(pre_w * pre_s + pre_b) + 1e-6
    C_l = 2.0 * torch.sigmoid(post_w * post_s + post_b)
    B_l = sinkhorn((comb_w * comb_s + comb_b).reshape(1, n_hc, n_hc))

    # Weighted sum of the streams -> single hidden vector for the attention input.
    x_in = (A_l.unsqueeze(-1) * X_l_ref.float()).sum(dim=1).to(torch.bfloat16)
    print(f"\n  mHC A_l: {A_l[0].tolist()}")
    print(f"  mHC C_l: {C_l[0].tolist()}")
    print(f"  B row sums: {B_l[0].sum(-1).tolist()}")

    # --- RMSNorm (same eps as the production attn_norm above) ---
    x_normed = rmsnorm(x_in, w[f"{layer}.input_layernorm.weight"].to(device), eps)

    # --- Q projection ---
    c_Q = rmsnorm(_nvfp4_proj(x_normed, w, f"{pre}.q_a_proj"),
                  w[f"{pre}.q_a_norm.weight"].to(device), eps)
    q = unweighted_rmsnorm(_nvfp4_proj(c_Q, w, f"{pre}.q_b_proj"))

    # --- KV projection (a single shared KV head) ---
    kv = rmsnorm(_nvfp4_proj(x_normed, w, f"{pre}.kv_proj"),
                 w[f"{pre}.kv_norm.weight"].to(device), eps)

    q_heads = q.reshape(1, n_h, hd)
    kv_new = kv.reshape(1, 1, hd)

    # RoPE. q is computed for completeness only: with a single token the
    # attention output below does not depend on it.
    q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)

    # Attention: a softmax over one key is 1, so every head outputs the shared
    # kv vector.
    # NOTE(review): assumes no attention sinks / extra logits in this model's
    # attention; confirm.
    attn_out = kv_new.expand(1, n_h, hd)

    # Inverse RoPE on the output.
    attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)

    # --- Grouped output projection ---
    per_group = (n_h // o_groups) * hd
    attn_grouped = attn_out.reshape(1, o_groups, per_group)
    oa_3d = _dense_weight(w, f"{pre}.o_a_proj").to(device).reshape(o_groups, o_rank, per_group)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    F_attn = _nvfp4_proj(grouped_flat, w, f"{pre}.o_b_proj")
    print(f"  |F_attn|={F_attn.abs().max().item():.4f}")

    # --- mHC post_block: X_mid = C_l * F_attn + B_l.T @ X_l ---
    BX = torch.bmm(B_l.transpose(-1, -2), X_l_ref.float())
    CF = C_l.unsqueeze(-1) * F_attn.unsqueeze(1)
    X_mid = (CF.float() + BX).to(torch.bfloat16)
    print(f"  |X_mid|={X_mid.abs().max().item():.4f}")
    print(f"  Stream norms (mid): {_stream_norms(X_mid)}")

    # --- Compare with production, per stream, in float32 ---
    # NOTE(review): forward_layer is also given the FFN mHC block and norm, so
    # X_prod is presumably the full-layer output, while this reference stops
    # after the attention sub-block (X_mid). Confirm what forward_layer
    # returns; if it includes the FFN, a mismatch here is expected.
    prod_f = X_prod.float()
    ref_f = X_mid.float()
    per_stream_cos = [
        torch.nn.functional.cosine_similarity(
            prod_f[0, s].unsqueeze(0), ref_f[0, s].unsqueeze(0)).item()
        for s in range(n_hc)
    ]
    max_diff = (prod_f - ref_f).abs().max().item()
    cos_str = ", ".join(f"{c:.6f}" for c in per_stream_cos)
    print(f"\n  Stream 0 cosine similarity: {per_stream_cos[0]:.6f}")
    print(f"  Per-stream cosine similarity: [{cos_str}]")
    print(f"  Max diff: {max_diff:.6f}")
    print("  Note: the reference stops after the attention sub-block (X_mid); "
          "if forward_layer also applies the FFN, differences are expected.")
    # `>=` (rather than `<` for the failure case) so that NaN counts as a mismatch.
    if all(c >= cos_threshold for c in per_stream_cos):
        print("  ✅ Match!")
    else:
        print("  ⚠️ MISMATCH! Production and reference differ significantly!")
    return X_mid


def main():
    """CLI entry point: load one layer's weights from the checkpoint and validate it."""
    from safetensors import safe_open

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument('--layer', type=int, default=0)
    p.add_argument('--device', type=str, default='cuda:0')
    p.add_argument('--checkpoint-dir', type=str, default=CHECKPOINT_DIR)
    args = p.parse_args()

    cdir = Path(args.checkpoint_dir)
    with open(cdir / "config.json") as f:
        cfg = json.load(f)

    li = args.layer
    print(f"Loading weights for layer {li}...")
    prefix = f"model.layers.{li}."
    extra_keys = {"model.embed_tokens.weight"}

    def wanted(key):
        return key.startswith(prefix) or key in extra_keys

    index_path = cdir / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})
        needed_shards = {shard for key, shard in weight_map.items() if wanted(key)}
    else:
        # Unsharded checkpoint: everything lives in a single file.
        needed_shards = {"model.safetensors"}

    all_w = {}
    for shard_name in sorted(needed_shards):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            print(f"  WARNING: shard {shard_name} not found, skipping")
            continue
        # safe_open reads only the requested tensors instead of the whole shard.
        with safe_open(str(shard_path), framework="pt", device="cpu") as f:
            for k in f.keys():
                if wanted(k):
                    all_w[k] = f.get_tensor(k).to(device=args.device, non_blocking=True)
    print(f"  {len(all_w)} weights loaded")

    validate_layer(li, all_w, cfg, device=args.device)


if __name__ == "__main__":
    main()