diff --git a/tests/verify_attention.py b/tests/verify_attention.py new file mode 100644 index 00000000..d7c8fbfb --- /dev/null +++ b/tests/verify_attention.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +"""Verify single-layer attention output matches a PyTorch reference. + +This script: +1. Loads layer 0 weights +2. Processes one token through the attention sub-block +3. Compares with a simple PyTorch reference +4. Identifies where the outputs diverge +""" +import os, sys, json, math, torch +from pathlib import Path + +CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" +LAYER_IDX = 0 + +# Correct E2M1 magnitudes +FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) + +def dequant_nvfp4(weight, weight_scale, weight_scale_2): + out_dim = weight.shape[0] + in_packed = weight.shape[1] + in_features = in_packed * 2 + low = (weight & 0x0F).to(torch.int8) + high = (weight >> 4).to(torch.int8) + low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long() + high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long() + lut = FP4_LUT.to(device=weight.device, dtype=torch.float32) + low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0) + high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0) + w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features) + scale_f = weight_scale.float() * weight_scale_2.float() + scale_expanded = scale_f.repeat_interleave(16, dim=1) + return (w_f * scale_expanded).bfloat16() + + +def main(): + with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: + cfg = json.load(f) + + n_h = cfg["num_attention_heads"] # 128 + hd = cfg["head_dim"] # 512 + H = cfg["hidden_size"] # 7168 + rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64)) # 64 + o_rank = cfg.get("output_group_dim", 1024) + o_groups = cfg.get("num_output_groups", 16) + heads_per_group = n_h // o_groups # 8 + group_input_dim = heads_per_group * hd # 4096 + + pre = f"model.layers.{LAYER_IDX}.self_attn" + + # Load weights for this layer + from safetensors.torch import load_file + cdir = Path(CHECKPOINT_DIR) + with open(cdir / "model.safetensors.index.json") as f: + wm = json.load(f)["weight_map"] + + w = {} + for key, shard in wm.items(): + if key.startswith(f"model.layers.{LAYER_IDX}.self_attn.") and "compressor" not in key and "indexer" not in key: + data = load_file(str(cdir / shard)) + w[key.split(f"model.layers.{LAYER_IDX}.")[1]] = data[key].cuda() + + print("Loaded weights:") + for k, v in sorted(w.items()): + print(f" {k}: {v.shape} {v.dtype}") + + # Create input: random hidden state after RMSNorm (unit scale) + torch.manual_seed(42) + x = torch.randn(1, H, dtype=torch.bfloat16, device='cuda:0') + # RMSNorm + x_f = x.float() + rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + x_normed = (x_f * rms).bfloat16() + + # === Q projection === + c_Q = dequant_nvfp4(x_normed, w["self_attn.q_a_proj.weight"], + w["self_attn.q_a_proj.weight_scale"], + w["self_attn.q_a_proj.weight_scale_2"]) + print(f"\nc_Q: shape={c_Q.shape}, |c_Q|={c_Q.abs().max():.4f}, mean={c_Q.float().mean():.4f}") + + # q_a_norm + q_norm_w = w["self_attn.q_a_norm.weight"] + c_Q_f = c_Q.float() + c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16() + print(f"After q_a_norm: |c_Q|={c_Q.abs().max():.4f}") + + q = dequant_nvfp4(c_Q, w["self_attn.q_b_proj.weight"], + w["self_attn.q_b_proj.weight_scale"], + w["self_attn.q_b_proj.weight_scale_2"]) + print(f"q: shape={q.shape}, |q|={q.abs().max():.4f}, mean={q.float().mean():.4f}") + + q_heads = q.reshape(1, n_h, hd) + print(f"q_heads: shape={q_heads.shape}, per-head norm={q_heads[0, 0].float().norm():.4f}") + + # === KV projection === + kv = dequant_nvfp4(x_normed, w["self_attn.kv_proj.weight"], + w["self_attn.kv_proj.weight_scale"], + w["self_attn.kv_proj.weight_scale_2"]) + print(f"\nkv: shape={kv.shape}, |kv|={kv.abs().max():.4f}, mean={kv.float().mean():.4f}") + + # kv_norm + kv_norm_w = w["self_attn.kv_norm.weight"] + kv_f = kv.float() + kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16() + print(f"After kv_norm: |kv|={kv.abs().max():.4f}") + + kv_heads = kv.reshape(1, 1, hd) # 1 KV head + print(f"kv_heads: shape={kv_heads.shape}") + + # === Apply RoPE === + half = rd // 2 + freqs = 1.0 / (10000.0 ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + pos = torch.tensor([0], dtype=torch.long) + cos = torch.cos(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16() + sin = torch.sin(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16() + + def apply_rope(x, cos, sin): + nope = hd - rd + out = x.clone() + x_rope = x[:, :, nope:] + out[:, :, nope:][..., 0::2] = x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin + out[:, :, nope:][..., 1::2] = x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos + return out + + q_roped = apply_rope(q_heads, cos.unsqueeze(0), sin.unsqueeze(0)) + kv_roped = apply_rope(kv_heads, cos.unsqueeze(0), sin.unsqueeze(0)) + + print(f"\nAfter RoPE: |q|={q_roped.abs().max():.4f}, |kv|={kv_roped.abs().max():.4f}") + + # === Attention (single KV entry → output = V) === + # For 1 KV entry, attention output = V (softmax of scalar = 1) + # With K=V (both RoPE'd), output = V_roped + # Then inverse RoPE should give back kv (pre-RoPE) + attn_out = kv_roped # (1, 1, hd) — just V + + # Inverse RoPE + def apply_inverse_rope(o, cos, sin): + nope = hd - rd + out = o.clone() + o_rope = o[:, :, nope:] + out[:, :, nope:][..., 0::2] = o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin + out[:, :, nope:][..., 1::2] = -o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos + return out + + attn_out_inv = apply_inverse_rope(attn_out, cos.unsqueeze(0), sin.unsqueeze(0)) + + # Check: inverse RoPE should recover the original kv (for single position) + diff = (attn_out_inv[0, 0].float() - kv_heads[0, 0].float()).abs().max() + print(f"Inverse RoPE recovery: max diff = {diff:.6f} (should be ~0)") + + # === Output projection === + attn_flat = attn_out_inv.reshape(1, n_h * hd) # (1, 65536) + print(f"\nattn_flat: shape={attn_flat.shape}, |attn_flat|={attn_flat.abs().max():.4f}") + + # wo_a: grouped linear + attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd) # (1, 16, 4096) + oa_w = w["self_attn.o_a_proj.weight"].bfloat16() # (16384, 4096) + oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim) # (16, 1024, 4096) + attn_for_bmm = attn_grouped.permute(1, 0, 2) # (16, 1, 4096) + grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) # (16, 1, 1024) + grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank) # (1, 16384) + + print(f"grouped_flat: shape={grouped_flat.shape}, |grouped_flat|={grouped_flat.abs().max():.4f}") + + # wo_b + F_attn = dequant_nvfp4(grouped_flat, w["self_attn.o_b_proj.weight"], + w["self_attn.o_b_proj.weight_scale"], + w["self_attn.o_b_proj.weight_scale_2"]) + print(f"F_attn: shape={F_attn.shape}, |F_attn|={F_attn.abs().max():.4f}") + + # Sanity: check that the output is on a reasonable scale + print(f"\n=== SUMMARY ===") + print(f"Input |x| = {x.abs().max():.4f}") + print(f"After norm |x_normed| = {x_normed.abs().max():.4f}") + print(f"Q latent |c_Q| = {c_Q.abs().max():.4f}") + print(f"Q heads |q| = {q.abs().max():.4f}") + print(f"KV |kv| = {kv.abs().max():.4f}") + print(f"Attn output (pre-proj) |attn| = {attn_out_inv.abs().max():.4f}") + print(f"F_attn (post-proj) |F| = {F_attn.abs().max():.4f}") + print(f"Scale ratio F_attn/x = {F_attn.abs().max()/x.abs().max():.4f}") + + +if __name__ == "__main__": + main()