#!/usr/bin/env python3
"""Sanity-check a single layer's attention sub-block on one token.

This script:
  1. Loads the attention weights (plus ``input_layernorm``) of layer
     ``LAYER_IDX`` from an NVFP4 checkpoint.
  2. Pushes one random token through input RMSNorm, q_a -> q_a_norm -> q_b,
     kv -> kv_norm, RoPE, single-entry attention, inverse RoPE, the grouped
     o_a projection and o_b.
  3. Checks invariants that must hold for a single KV entry: inverse RoPE
     recovers the pre-RoPE vector, and every head's output equals V.
  4. Prints per-stage magnitudes so the stage where values blow up or vanish
     can be identified.
"""
import json
import math
import os
import sys
from collections import defaultdict
from pathlib import Path

import torch

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
LAYER_IDX = 0
# Token position used for RoPE. At position 0 RoPE is the identity (cos=1,
# sin=0), so the inverse-RoPE recovery check is trivially satisfied; set a
# non-zero position to actually exercise the rotation.
ROPE_POSITION = 0

# Correct E2M1 magnitudes, indexed by the low 3 bits of a FP4 code.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2, *, block_size=16):
    """Dequantize a packed NVFP4 weight matrix to BF16.

    Args:
        weight: (out, in // 2) integer tensor, two E2M1 codes per byte. The low
            nibble is the even input column, the high nibble the odd one. Bit 3
            of a nibble is the sign; bits 0-2 index ``FP4_LUT``.
        weight_scale: (out, in // block_size) per-block scales.
        weight_scale_2: global scale, broadcast against ``weight_scale``.
        block_size: number of input columns sharing one ``weight_scale`` entry.

    Returns:
        (out, in) BF16 tensor ``code * weight_scale * weight_scale_2``.

    Raises:
        ValueError: if the scale shape does not match the unpacked width.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(nibble):
        # Sign-magnitude E2M1: bit 3 is the sign, bits 0-2 select the magnitude.
        nibble = nibble.long()
        magnitude = lut[nibble & 0x07]
        return torch.where((nibble & 0x08) != 0, -magnitude, magnitude)

    low = weight & 0x0F
    # Mask after shifting so a signed (int8) packing cannot leak sign bits.
    high = (weight >> 4) & 0x0F
    w_f = torch.stack([decode(low), decode(high)], dim=-1).reshape(out_dim, in_features)

    scale_f = weight_scale.float() * weight_scale_2.float()
    if scale_f.shape[-1] * block_size != in_features:
        raise ValueError(
            f"weight_scale width {scale_f.shape[-1]} x block_size {block_size} "
            f"!= unpacked in_features {in_features}")
    return (w_f * scale_f.repeat_interleave(block_size, dim=1)).bfloat16()


def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """BF16 ``x @ W.T`` where W is dequantized from NVFP4 on the fly."""
    w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
    return torch.nn.functional.linear(x, w)


def rmsnorm(x, weight, eps=1e-6):
    """RMS-normalize ``x`` over its last dim in FP32, scale by ``weight``, return BF16."""
    x_f = x.float()
    rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * rms * weight.float()).bfloat16()


def rope_cos_sin(position, rope_dim, theta=10000.0, device="cuda"):
    """Return BF16 cos/sin tables of shape (1, rope_dim // 2) for one token position."""
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim))
    angles = torch.tensor([float(position)], device=device).unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cos(angles).bfloat16(), torch.sin(angles).bfloat16()


def apply_rope(x, cos, sin, rope_dim):
    """Rotate the trailing ``rope_dim`` features of ``x`` in interleaved (even, odd) pairs.

    The leading ``x.shape[-1] - rope_dim`` ("nope") features are copied unchanged.
    ``cos``/``sin`` must broadcast against ``x[..., ::2]`` restricted to the rope part.
    """
    nope = x.shape[-1] - rope_dim
    out = x.clone()
    even, odd = x[..., nope::2], x[..., nope + 1::2]
    out[..., nope::2] = even * cos - odd * sin
    out[..., nope + 1::2] = even * sin + odd * cos
    return out


def apply_inverse_rope(x, cos, sin, rope_dim):
    """Undo ``apply_rope``: rotating by the negated angle is the inverse rotation."""
    return apply_rope(x, cos, -sin, rope_dim)


def load_layer_weights(cdir, layer_idx):
    """Load one layer's attention weights and its ``input_layernorm`` weight onto the GPU.

    Compressor and indexer tensors are skipped. Each shard is opened once with
    ``safe_open`` and only the requested tensors are read, rather than
    materializing the whole shard once per key.

    Returns:
        dict mapping checkpoint key -> tensor on the current CUDA device.
    """
    from safetensors import safe_open

    cdir = Path(cdir)
    with open(cdir / "model.safetensors.index.json") as f:
        weight_map = json.load(f)["weight_map"]

    layer_prefix = f"model.layers.{layer_idx}."
    attn_prefix = f"{layer_prefix}self_attn."

    def wanted(key):
        if key.startswith(attn_prefix):
            return "compressor" not in key and "indexer" not in key
        # The input norm lives outside self_attn; it must be loaded explicitly.
        return key.startswith(layer_prefix) and "input_layernorm" in key

    keys_by_shard = defaultdict(list)
    for key, shard in weight_map.items():
        if wanted(key):
            keys_by_shard[shard].append(key)

    weights = {}
    for shard, keys in keys_by_shard.items():
        with safe_open(str(cdir / shard), framework="pt") as f:
            for key in keys:
                weights[key] = f.get_tensor(key).cuda()
    return weights


def main():
    """Run layer ``LAYER_IDX``'s attention sub-block on one random token and report magnitudes."""
    cdir = Path(CHECKPOINT_DIR)
    with open(cdir / "config.json") as f:
        cfg = json.load(f)
    n_h = cfg["num_attention_heads"]  # 128
    hd = cfg["head_dim"]  # 512
    H = cfg["hidden_size"]  # 7168
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))  # 64
    rope_theta = float(cfg.get("rope_theta", 10000.0))
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    heads_per_group = n_h // o_groups  # 8
    group_input_dim = heads_per_group * hd  # 4096
    pre = f"model.layers.{LAYER_IDX}.self_attn"

    w = load_layer_weights(cdir, LAYER_IDX)
    print("Loaded attention weights:")
    for k in sorted(w):
        print(f"  ...{k.split('self_attn.', 1)[-1]}: {w[k].shape} {w[k].dtype}")

    def proj(x, name):
        # All q/kv/o_b projections are stored as NVFP4 (weight + two scales).
        return nvfp4_linear(x, w[f"{pre}.{name}.weight"],
                            w[f"{pre}.{name}.weight_scale"],
                            w[f"{pre}.{name}.weight_scale_2"])

    # Input: one random hidden state.
    torch.manual_seed(42)
    x_raw = torch.randn(1, H, dtype=torch.bfloat16, device="cuda:0")
    print(f"\nInput |x_raw| = {x_raw.abs().max():.4f}")

    # RMSNorm
    input_layernorm_w = next((v for k, v in w.items() if "input_layernorm" in k), None)
    if input_layernorm_w is not None:
        x_normed = rmsnorm(x_raw, input_layernorm_w)
    else:
        print("WARNING: input_layernorm weight not found; skipping input RMSNorm")
        x_normed = x_raw
    print(f"After RMSNorm |x_normed| = {x_normed.abs().max():.4f}")

    # === Q projection: q_a → q_a_norm → q_b ===
    c_Q = proj(x_normed, "q_a_proj")
    print(f"\nc_Q: shape={c_Q.shape}, |c_Q|={c_Q.abs().max():.4f}")
    if f"{pre}.q_a_norm.weight" in w:
        c_Q = rmsnorm(c_Q, w[f"{pre}.q_a_norm.weight"])
        print(f"After q_a_norm: |c_Q|={c_Q.abs().max():.4f}")
    q = proj(c_Q, "q_b_proj")
    print(f"q: shape={q.shape}, |q|={q.abs().max():.4f}")
    q_heads = q.reshape(1, n_h, hd)
    print(f"q_heads: shape={q_heads.shape}")

    # === KV projection + kv_norm ===
    kv = proj(x_normed, "kv_proj")
    print(f"\nkv: shape={kv.shape}, |kv|={kv.abs().max():.4f}")
    if f"{pre}.kv_norm.weight" in w:
        kv = rmsnorm(kv, w[f"{pre}.kv_norm.weight"])
        print(f"After kv_norm: |kv|={kv.abs().max():.4f}")
    kv_heads = kv.reshape(1, 1, hd)  # 1 KV head
    print(f"kv_heads: shape={kv_heads.shape}")

    # === Apply RoPE ===
    cos, sin = rope_cos_sin(ROPE_POSITION, rd, rope_theta, device="cuda")
    cos, sin = cos.unsqueeze(0), sin.unsqueeze(0)  # (1, 1, rd // 2): broadcast over heads
    q_roped = apply_rope(q_heads, cos, sin, rd)
    kv_roped = apply_rope(kv_heads, cos, sin, rd)

    # === Single KV entry: softmax over one key is 1, so attention output = V ===
    # K = V (both RoPE'd), so inverse RoPE of the output must recover pre-RoPE kv.
    attn_out_inv = apply_inverse_rope(kv_roped, cos, sin, rd)
    diff = (attn_out_inv[0, 0].float() - kv_heads[0, 0].float()).abs().max()
    print(f"\nInverse RoPE recovery: max diff = {diff:.6f} (should be ~0)")

    # === Multi-head attention with SDPA (all heads share the single KV head) ===
    q_sdpa = q_roped.transpose(0, 1)  # (n_h, 1, hd)
    kv_sdpa = kv_roped.squeeze(0).expand(n_h, -1, -1)  # (n_h, 1, hd)
    attn_out_full = torch.nn.functional.scaled_dot_product_attention(
        q_sdpa, kv_sdpa, kv_sdpa, scale=1.0 / math.sqrt(hd), is_causal=False,
    ).transpose(0, 1)  # (1, n_h, hd)

    attn_out_inv = apply_inverse_rope(attn_out_full, cos, sin, rd)
    print(f"\nMulti-head attn output: shape={attn_out_inv.shape}, |attn|={attn_out_inv.abs().max():.4f}")
    # With a single KV entry every head's output must equal kv.
    for h in [0, 1, 63, 127]:
        diff = (attn_out_inv[0, h].float() - kv_heads[0, 0].float()).abs().max()
        print(f"  Head {h}: max diff from kv = {diff:.6f}")

    attn_flat = attn_out_inv.reshape(1, n_h * hd)  # (1, 65536)
    print(f"\nattn_flat: shape={attn_flat.shape}, |attn_flat|={attn_flat.abs().max():.4f}")

    # === wo_a: grouped linear, one (o_rank x group_input_dim) matrix per group ===
    oa_prefix = f"{pre}.o_a_proj"
    oa_raw = w[f"{oa_prefix}.weight"]
    if oa_raw.dtype == torch.uint8 and f"{oa_prefix}.weight_scale" in w:
        # Packed NVFP4: a plain cast would reinterpret raw bytes as numbers.
        oa_w = dequant_nvfp4_weight(oa_raw, w[f"{oa_prefix}.weight_scale"],
                                    w[f"{oa_prefix}.weight_scale_2"])
    else:
        oa_w = oa_raw.bfloat16()  # (16384, 4096)
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)  # (16, 1024, 4096)
    attn_grouped = attn_flat.reshape(1, o_groups, group_input_dim)  # (1, 16, 4096)
    grouped_out = torch.bmm(attn_grouped.transpose(0, 1), oa_3d.transpose(1, 2))  # (16, 1, 1024)
    grouped_flat = grouped_out.transpose(0, 1).reshape(1, o_groups * o_rank)  # (1, 16384)
    print(f"grouped_flat: shape={grouped_flat.shape}, |grouped_flat|={grouped_flat.abs().max():.4f}")

    # === wo_b ===
    F_attn = proj(grouped_flat, "o_b_proj")
    print(f"F_attn: shape={F_attn.shape}, |F_attn|={F_attn.abs().max():.4f}")

    # === Sanity checks ===
    print(f"\n{'='*50}")
    print(f"ATTENTION SUB-BLOCK SUMMARY (Layer {LAYER_IDX})")
    print(f"{'='*50}")
    print(f"Input |x_normed| = {x_normed.abs().max():.4f}")
    print(f"Q latent |c_Q| = {c_Q.abs().max():.4f}")
    print(f"Q heads |q| = {q.abs().max():.4f}")
    print(f"KV |kv| = {kv.abs().max():.4f}")
    print(f"Attn out |attn_inv| = {attn_out_inv.abs().max():.4f}")
    print(f"Grouped |grouped| = {grouped_flat.abs().max():.4f}")
    print(f"F_attn (output) |F| = {F_attn.abs().max():.4f}")
    f_max = F_attn.abs().max().item()
    x_max = x_normed.abs().max().item()
    print(f"Scale ratio F/x_norm = {f_max / max(x_max, 1e-8):.4f}")

    # Is F_attn on a reasonable scale?
    if f_max > 100:
        print(f"\n⚠️ WARNING: F_attn is very large ({f_max:.1f}). "
              f"This will cause residual growth in the full model.")
    elif f_max < 0.01:
        print(f"\n⚠️ WARNING: F_attn is very small. "
              f"Attention output is being suppressed.")
    else:
        print(f"\n✅ F_attn is on a reasonable scale.")


if __name__ == "__main__":
    main()