diff --git a/tests/validate_layer.py b/tests/validate_layer.py new file mode 100644 index 00000000..2b5ea33e --- /dev/null +++ b/tests/validate_layer.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +"""Per-layer validation: compare forward_layer output with a step-by-step reference. + +This test takes a known embedding, processes it through a SINGLE layer, +and compares the output at each intermediate step between the production +forward_layer function and a standalone PyTorch reference. + +Usage (on B200): + source /root/dsv4-nvfp4-workspace/venv/bin/activate + cd /root/dsv4-nvfp4-workspace/kernel + python3 tests/validate_layer.py --layer 0 +""" +import os, sys, json, math, argparse, torch +from pathlib import Path + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" + +# ===================================================================== +# NVFP4 dequantization +# ===================================================================== +FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) + +def dequant_nvfp4(weight, weight_scale, weight_scale_2): + out_dim = weight.shape[0] + in_packed = weight.shape[1] + in_features = in_packed * 2 + low = (weight & 0x0F).to(torch.int8) + high = (weight >> 4).to(torch.int8) + low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long() + high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long() + lut = FP4_LUT.to(device=weight.device, dtype=torch.float32) + low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0) + high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0) + w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features) + scale_f = weight_scale.float() * weight_scale_2.float() + scale_expanded = scale_f.repeat_interleave(16, dim=1) + return (w_f * scale_expanded).bfloat16() + +def nvfp4_linear(x, weight, weight_scale, weight_scale_2): + w = dequant_nvfp4(weight, weight_scale, weight_scale_2) + return torch.nn.functional.linear(x, w) + +def rmsnorm(x, weight, eps=1e-6): + """Weighted RMSNorm matching HF DeepseekV4RMSNorm.""" + x_f = x.float() + rms_inv = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() + return (x_f * rms_inv * weight.float()).to(torch.bfloat16) + +def unweighted_rmsnorm(x, eps=1e-6): + """Unweighted RMSNorm matching HF DeepseekV4UnweightedRMSNorm.""" + x_f = x.float() + rms_inv = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() + return (x_f * rms_inv).to(torch.bfloat16) + + +def sinkhorn(logits, t_max=20, eps=1e-6): + """Sinkhorn-Knopp from softmax, matching HF.""" + M = torch.softmax(logits, dim=-1) + eps + M = M / (M.sum(dim=-2, keepdim=True) + eps) + for _ in range(t_max - 1): + M = M / (M.sum(dim=-1, keepdim=True) + eps) + M = M / (M.sum(dim=-2, keepdim=True) + eps) + return M + + +def build_rope_cache(max_pos, rope_dim, device, theta=10000.0, + rope_type="default", rope_factor=1.0, + original_max_pos=4096, beta_fast=32, beta_slow=1): + half = rope_dim // 2 + freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) + if rope_type == "yarn" and rope_factor > 1.0: + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < original_max_pos / (beta_fast * 2.0): + new_freqs.append(freq) + elif wavelen > original_max_pos / (beta_slow * 2.0): + new_freqs.append(freq / rope_factor) + else: + smooth = (original_max_pos / (wavelen * beta_slow) - rope_factor) / ( + rope_factor * (beta_fast / beta_slow - 1)) + new_freqs.append((1 - smooth) * freq / rope_factor + smooth * freq) + freqs = torch.tensor(new_freqs, dtype=torch.float32) + angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs) + return torch.cos(angles).to(device), torch.sin(angles).to(device) + + +def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim): + T, n_h, hd = x.shape + nope = hd - rope_dim + cos = cos_cache[positions].unsqueeze(1) + sin = sin_cache[positions].unsqueeze(1) + x_rope = x[:, :, nope:].float() + x_even = x_rope[..., 0::2] + x_odd = x_rope[..., 1::2] + rot_even = x_even * cos - x_odd * sin + rot_odd = x_even * sin + x_odd * cos + result = x.clone() + rope_out = torch.empty_like(x_rope) + rope_out[..., 0::2] = rot_even + rope_out[..., 1::2] = rot_odd + result[:, :, nope:] = rope_out.to(torch.bfloat16) + return result + + +def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): + T, n_h, hd = o.shape + nope = hd - rope_dim + cos = cos_cache[positions].unsqueeze(1) + sin = sin_cache[positions].unsqueeze(1) + o_rope = o[:, :, nope:].float() + o_even = o_rope[..., 0::2] + o_odd = o_rope[..., 1::2] + inv_even = o_even * cos + o_odd * sin + inv_odd = -o_even * sin + o_odd * cos + result = o.clone() + rope_out = torch.empty_like(o_rope) + rope_out[..., 0::2] = inv_even + rope_out[..., 1::2] = inv_odd + result[:, :, nope:] = rope_out.to(torch.bfloat16) + return result + + +def validate_layer(li, all_weights, cfg, device='cuda:0'): + """Validate a single layer by running both forward_layer and a step-by-step reference.""" + from single_shot_inference import forward_layer, mHCBlock, RMSNorm, SimpleKVCache + + n_h = cfg["num_attention_heads"] + hd = cfg["head_dim"] + rd = cfg.get("qk_rope_head_dim", 64) + H = cfg["hidden_size"] + o_groups = cfg.get("o_groups", 16) + o_rank = cfg.get("o_group_dim", 1024) + n_hc = 4 + pre = f"model.layers.{li}.self_attn" + + # Get weights + w = all_weights # Already filtered and on device + + # Build RoPE caches + rope_params = cfg.get("rope_parameters", {}) + rope_type = rope_params.get("rope_type", "yarn") + rope_factor = rope_params.get("factor", 16.0) + rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0)) + original_max_pos = rope_params.get("original_max_position_embeddings", 4096) + beta_fast = rope_params.get("beta_fast", 32) + beta_slow = rope_params.get("beta_slow", 1) + rope_cos, rope_sin = build_rope_cache( + 8192, rd, device, theta=rope_theta, + rope_type=rope_type, rope_factor=rope_factor, + original_max_pos=original_max_pos, + beta_fast=beta_fast, beta_slow=beta_slow + ) + + # Create mHC blocks + attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device) + ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device) + fn_key = f"model.layers.{li}.attn_hc.fn" + base_key = f"model.layers.{li}.attn_hc.base" + scale_key = f"model.layers.{li}.attn_hc.scale" + if fn_key in w and base_key in w and scale_key in w: + attn_mhc.load_from_checkpoint(w[fn_key], w[base_key], w[scale_key]) + fn_key = f"model.layers.{li}.ffn_hc.fn" + base_key = f"model.layers.{li}.ffn_hc.base" + scale_key = f"model.layers.{li}.ffn_hc.scale" + if fn_key in w and base_key in w and scale_key in w: + ffn_mhc.load_from_checkpoint(w[fn_key], w[base_key], w[scale_key]) + + attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=device) + an_key = f"model.layers.{li}.input_layernorm.weight" + if an_key in w: + attn_norm.weight = w[an_key].to(device=device, dtype=torch.float32) + + ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=device) + fn_key = f"model.layers.{li}.post_attention_layernorm.weight" + if fn_key in w: + ffn_norm.weight = w[fn_key].to(device=device, dtype=torch.float32) + + kv_cache = SimpleKVCache(head_dim=hd, max_seq=8192, device=device) + + # Create input: random embedding + torch.manual_seed(42) + X_l = torch.randn(1, n_hc, H, dtype=torch.bfloat16, device=device) * 0.5 + positions = torch.tensor([0], dtype=torch.long, device=device) + token_id = torch.tensor([671], dtype=torch.long, device=device) # "The" + + # Run forward_layer (production code) + X_prod = forward_layer(X_l.clone(), w, li, cfg, rope_cos, rope_sin, + attn_mhc, ffn_mhc, attn_norm, ffn_norm, + kv_cache, token_id, positions) + + print(f"Production: |X_next|={X_prod.abs().max().item():.4f}") + print(f" Stream norms: {[X_prod[0,s,:].float().norm().item() for s in range(4)]}") + + # ============================================================ + # Step-by-step reference (same math, no wrappers) + # ============================================================ + X_l_ref = X_l.clone() + + # --- mHC pre_block (attention) --- + fn = w[f"model.layers.{li}.attn_hc.fn"].float().to(device) + base = w[f"model.layers.{li}.attn_hc.base"].float().to(device) + scale = w[f"model.layers.{li}.attn_hc.scale"].float().to(device) + + # Unweighted RMSNorm on flattened residual + X_flat = X_l_ref.reshape(1, n_hc * H).float() + rms_inv = X_flat.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + flat = (X_flat * rms_inv).to(torch.bfloat16) + + # F.linear + split [pre(4), post(4), comb(16)] + proj = torch.nn.functional.linear(flat, fn).float() + pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1) + pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc]) + pre_s, post_s, comb_s = scale.unbind(0) + + A_l = torch.sigmoid(pre_w * pre_s + pre_b) + 1e-6 + C_l = 2.0 * torch.sigmoid(post_w * post_s + post_b) + B_l = sinkhorn((comb_w * comb_s + comb_b).reshape(1, n_hc, n_hc)) + + x_in = (A_l.unsqueeze(-1) * X_l_ref.float()).sum(dim=1).to(torch.bfloat16) + print(f"\n mHC A_l: {A_l[0].tolist()}") + print(f" mHC C_l: {C_l[0].tolist()}") + print(f" B row sums: {B_l[0].sum(-1).tolist()}") + + # --- RMSNorm --- + x_normed = rmsnorm(x_in, w[f"model.layers.{li}.input_layernorm.weight"].to(device)) + + # --- Q projection --- + c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"]) + c_Q = rmsnorm(c_Q, w[f"{pre}.q_a_norm.weight"].to(device)) + q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"]) + q = unweighted_rmsnorm(q) + + # --- KV projection --- + kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"]) + kv = rmsnorm(kv, w[f"{pre}.kv_norm.weight"].to(device)) + + q_heads = q.reshape(1, n_h, hd) + kv_new = kv.reshape(1, 1, hd) + + # RoPE + q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd) + kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd) + + # Attention (single token → self-attention is identity) + attn_out = kv_new.expand(1, n_h, hd) + + # Inverse RoPE + attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd) + + # Output projection + attn_flat = attn_out.reshape(1, n_h * hd) + attn_grouped = attn_flat.reshape(1, o_groups, (n_h // o_groups) * hd) + oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16().to(device) + oa_3d = oa_w.reshape(o_groups, o_rank, (n_h // o_groups) * hd) + grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2)) + grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank) + F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"]) + + print(f" |F_attn|={F_attn.abs().max().item():.4f}") + + # mHC post_block: X_mid = C_l * F_attn + B_l.T @ X_l + BX = torch.bmm(B_l.transpose(-1, -2), X_l_ref.float()) + CF = C_l.unsqueeze(-1) * F_attn.unsqueeze(1) + X_mid = (CF.float() + BX).to(torch.bfloat16) + + print(f" |X_mid|={X_mid.abs().max().item():.4f}") + print(f" Stream norms (mid): {[X_mid[0,s,:].float().norm().item() for s in range(4)]}") + + # Compare X_mid with production + X_prod_mid_stream0 = X_prod[0, 0, :].float() + X_ref_mid_stream0 = X_mid[0, 0, :].float() + cos_sim = torch.nn.functional.cosine_similarity(X_prod_mid_stream0.unsqueeze(0), X_ref_mid_stream0.unsqueeze(0)).item() + max_diff = (X_prod - X_mid).abs().max().item() + print(f"\n Stream 0 cosine similarity: {cos_sim:.6f}") + print(f" Max diff: {max_diff:.6f}") + + if cos_sim < 0.99: + print(" ⚠️ MISMATCH! Production and reference differ significantly!") + else: + print(" ✅ Match!") + + return X_mid + + +def main(): + from safetensors.torch import load_file + + p = argparse.ArgumentParser() + p.add_argument('--layer', type=int, default=0) + p.add_argument('--device', type=str, default='cuda:0') + args = p.parse_args() + + with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: + cfg = json.load(f) + + print(f"Loading weights for layer {args.layer}...") + cdir = Path(CHECKPOINT_DIR) + index_path = cdir / "model.safetensors.index.json" + weight_map = {} + if index_path.exists(): + with open(index_path) as f: + weight_map = json.load(f).get("weight_map", {}) + + # Find which shards contain our layer + li = args.layer + prefix = f"model.layers.{li}." + needed_shards = set() + for key, shard in weight_map.items(): + if key.startswith(prefix) or key in ["model.embed_tokens.weight", "model.norm.weight"]: + needed_shards.add(shard) + + all_w = {} + for shard_name in sorted(needed_shards): + if not (cdir / shard_name).exists(): + continue + data = load_file(str(cdir / shard_name)) + for k, v in data.items(): + if k.startswith(prefix) or k in ["model.embed_tokens.weight"]: + all_w[k] = v.to(device=args.device, non_blocking=True) + + print(f" {len(all_w)} weights loaded") + validate_layer(li, all_w, cfg, device=args.device) + + +if __name__ == "__main__": + main()