From 2333fc8b4bef09c8ec3cb2dccc960242edf84dfe Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 05:53:49 +0000 Subject: [PATCH] fix verify_attention.py: proper nvfp4_linear calls --- tests/verify_attention.py | 167 ++++++++++++++++++++++---------------- 1 file changed, 99 insertions(+), 68 deletions(-) diff --git a/tests/verify_attention.py b/tests/verify_attention.py index d7c8fbfb..816360e5 100644 --- a/tests/verify_attention.py +++ b/tests/verify_attention.py @@ -16,7 +16,8 @@ LAYER_IDX = 0 # Correct E2M1 magnitudes FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) -def dequant_nvfp4(weight, weight_scale, weight_scale_2): +def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2): + """Dequantize NVFP4 weight to BF16.""" out_dim = weight.shape[0] in_packed = weight.shape[1] in_features = in_packed * 2 @@ -33,6 +34,18 @@ def dequant_nvfp4(weight, weight_scale, weight_scale_2): return (w_f * scale_expanded).bfloat16() +def nvfp4_linear(x, weight, weight_scale, weight_scale_2): + """BF16 linear with NVFP4 dequant.""" + w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2) + return torch.nn.functional.linear(x, w) + + +def rmsnorm(x, weight, eps=1e-6): + x_f = x.float() + rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() + return (x_f * rms * weight.float()).bfloat16() + + def main(): with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: cfg = json.load(f) @@ -58,62 +71,70 @@ def main(): for key, shard in wm.items(): if key.startswith(f"model.layers.{LAYER_IDX}.self_attn.") and "compressor" not in key and "indexer" not in key: data = load_file(str(cdir / shard)) - w[key.split(f"model.layers.{LAYER_IDX}.")[1]] = data[key].cuda() + w[key] = data[key].cuda() - print("Loaded weights:") - for k, v in sorted(w.items()): - print(f" {k}: {v.shape} {v.dtype}") + print("Loaded attention weights:") + for k in sorted(w.keys()): + if "self_attn" in k: + print(f" ...{k.split('self_attn.')[1]}: {w[k].shape} {w[k].dtype}") - # Create input: random hidden state after RMSNorm (unit scale) + # Create input: random hidden state torch.manual_seed(42) - x = torch.randn(1, H, dtype=torch.bfloat16, device='cuda:0') - # RMSNorm - x_f = x.float() - rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - x_normed = (x_f * rms).bfloat16() + x_raw = torch.randn(1, H, dtype=torch.bfloat16, device='cuda:0') + print(f"\nInput |x_raw| = {x_raw.abs().max():.4f}") - # === Q projection === - c_Q = dequant_nvfp4(x_normed, w["self_attn.q_a_proj.weight"], - w["self_attn.q_a_proj.weight_scale"], - w["self_attn.q_a_proj.weight_scale_2"]) - print(f"\nc_Q: shape={c_Q.shape}, |c_Q|={c_Q.abs().max():.4f}, mean={c_Q.float().mean():.4f}") + # RMSNorm + input_layernorm_w = None + for k in w: + if 'input_layernorm' in k: + input_layernorm_w = w[k] + break + if input_layernorm_w is not None: + x_normed = rmsnorm(x_raw, input_layernorm_w) + else: + x_normed = x_raw + print(f"After RMSNorm |x_normed| = {x_normed.abs().max():.4f}") + + # === Q projection: q_a → q_a_norm → q_b === + c_Q = nvfp4_linear(x_normed, + w[f"{pre}.q_a_proj.weight"], + w[f"{pre}.q_a_proj.weight_scale"], + w[f"{pre}.q_a_proj.weight_scale_2"]) + print(f"\nc_Q: shape={c_Q.shape}, |c_Q|={c_Q.abs().max():.4f}") # q_a_norm - q_norm_w = w["self_attn.q_a_norm.weight"] - c_Q_f = c_Q.float() - c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16() - print(f"After q_a_norm: |c_Q|={c_Q.abs().max():.4f}") + if f"{pre}.q_a_norm.weight" in w: + c_Q = rmsnorm(c_Q, w[f"{pre}.q_a_norm.weight"]) + print(f"After q_a_norm: |c_Q|={c_Q.abs().max():.4f}") - q = dequant_nvfp4(c_Q, w["self_attn.q_b_proj.weight"], - w["self_attn.q_b_proj.weight_scale"], - w["self_attn.q_b_proj.weight_scale_2"]) - print(f"q: shape={q.shape}, |q|={q.abs().max():.4f}, mean={q.float().mean():.4f}") + q = nvfp4_linear(c_Q, + w[f"{pre}.q_b_proj.weight"], + w[f"{pre}.q_b_proj.weight_scale"], + w[f"{pre}.q_b_proj.weight_scale_2"]) + print(f"q: shape={q.shape}, |q|={q.abs().max():.4f}") q_heads = q.reshape(1, n_h, hd) - print(f"q_heads: shape={q_heads.shape}, per-head norm={q_heads[0, 0].float().norm():.4f}") + print(f"q_heads: shape={q_heads.shape}") - # === KV projection === - kv = dequant_nvfp4(x_normed, w["self_attn.kv_proj.weight"], - w["self_attn.kv_proj.weight_scale"], - w["self_attn.kv_proj.weight_scale_2"]) - print(f"\nkv: shape={kv.shape}, |kv|={kv.abs().max():.4f}, mean={kv.float().mean():.4f}") + # === KV projection + kv_norm === + kv = nvfp4_linear(x_normed, + w[f"{pre}.kv_proj.weight"], + w[f"{pre}.kv_proj.weight_scale"], + w[f"{pre}.kv_proj.weight_scale_2"]) + print(f"\nkv: shape={kv.shape}, |kv|={kv.abs().max():.4f}") - # kv_norm - kv_norm_w = w["self_attn.kv_norm.weight"] - kv_f = kv.float() - kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16() - print(f"After kv_norm: |kv|={kv.abs().max():.4f}") + if f"{pre}.kv_norm.weight" in w: + kv = rmsnorm(kv, w[f"{pre}.kv_norm.weight"]) + print(f"After kv_norm: |kv|={kv.abs().max():.4f}") kv_heads = kv.reshape(1, 1, hd) # 1 KV head print(f"kv_heads: shape={kv_heads.shape}") # === Apply RoPE === half = rd // 2 - freqs = 1.0 / (10000.0 ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) - pos = torch.tensor([0], dtype=torch.long) - cos = torch.cos(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16() + freqs = 1.0 / (10000.0 ** (torch.arange(0, rd, 2, dtype=torch.float32, device='cuda') / rd)) + pos = torch.tensor([0], dtype=torch.long, device='cuda') + cos = torch.cos(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16() # (1, half) sin = torch.sin(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16() def apply_rope(x, cos, sin): @@ -124,18 +145,6 @@ def main(): out[:, :, nope:][..., 1::2] = x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos return out - q_roped = apply_rope(q_heads, cos.unsqueeze(0), sin.unsqueeze(0)) - kv_roped = apply_rope(kv_heads, cos.unsqueeze(0), sin.unsqueeze(0)) - - print(f"\nAfter RoPE: |q|={q_roped.abs().max():.4f}, |kv|={kv_roped.abs().max():.4f}") - - # === Attention (single KV entry → output = V) === - # For 1 KV entry, attention output = V (softmax of scalar = 1) - # With K=V (both RoPE'd), output = V_roped - # Then inverse RoPE should give back kv (pre-RoPE) - attn_out = kv_roped # (1, 1, hd) — just V - - # Inverse RoPE def apply_inverse_rope(o, cos, sin): nope = hd - rd out = o.clone() @@ -144,11 +153,20 @@ def main(): out[:, :, nope:][..., 1::2] = -o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos return out + q_roped = apply_rope(q_heads, cos.unsqueeze(0), sin.unsqueeze(0)) + kv_roped = apply_rope(kv_heads, cos.unsqueeze(0), sin.unsqueeze(0)) + + # === Attention (single KV entry → output = V) === + # For 1 KV entry, attention output = V (softmax of scalar = 1) + # With K=V (both RoPE'd), output = V_roped + # Then inverse RoPE should recover kv (pre-RoPE) + attn_out = kv_roped # (1, 1, hd) — just V + attn_out_inv = apply_inverse_rope(attn_out, cos.unsqueeze(0), sin.unsqueeze(0)) - # Check: inverse RoPE should recover the original kv (for single position) + # Check inverse RoPE recovery diff = (attn_out_inv[0, 0].float() - kv_heads[0, 0].float()).abs().max() - print(f"Inverse RoPE recovery: max diff = {diff:.6f} (should be ~0)") + print(f"\nInverse RoPE recovery: max diff = {diff:.6f} (should be ~0)") # === Output projection === attn_flat = attn_out_inv.reshape(1, n_h * hd) # (1, 65536) @@ -156,7 +174,7 @@ def main(): # wo_a: grouped linear attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd) # (1, 16, 4096) - oa_w = w["self_attn.o_a_proj.weight"].bfloat16() # (16384, 4096) + oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16() # (16384, 4096) oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim) # (16, 1024, 4096) attn_for_bmm = attn_grouped.permute(1, 0, 2) # (16, 1, 4096) grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) # (16, 1, 1024) @@ -165,21 +183,34 @@ def main(): print(f"grouped_flat: shape={grouped_flat.shape}, |grouped_flat|={grouped_flat.abs().max():.4f}") # wo_b - F_attn = dequant_nvfp4(grouped_flat, w["self_attn.o_b_proj.weight"], - w["self_attn.o_b_proj.weight_scale"], - w["self_attn.o_b_proj.weight_scale_2"]) + F_attn = nvfp4_linear(grouped_flat, + w[f"{pre}.o_b_proj.weight"], + w[f"{pre}.o_b_proj.weight_scale"], + w[f"{pre}.o_b_proj.weight_scale_2"]) print(f"F_attn: shape={F_attn.shape}, |F_attn|={F_attn.abs().max():.4f}") - # Sanity: check that the output is on a reasonable scale - print(f"\n=== SUMMARY ===") - print(f"Input |x| = {x.abs().max():.4f}") - print(f"After norm |x_normed| = {x_normed.abs().max():.4f}") - print(f"Q latent |c_Q| = {c_Q.abs().max():.4f}") - print(f"Q heads |q| = {q.abs().max():.4f}") - print(f"KV |kv| = {kv.abs().max():.4f}") - print(f"Attn output (pre-proj) |attn| = {attn_out_inv.abs().max():.4f}") - print(f"F_attn (post-proj) |F| = {F_attn.abs().max():.4f}") - print(f"Scale ratio F_attn/x = {F_attn.abs().max()/x.abs().max():.4f}") + # === Sanity checks === + print(f"\n{'='*50}") + print(f"ATTENTION SUB-BLOCK SUMMARY (Layer {LAYER_IDX})") + print(f"{'='*50}") + print(f"Input |x_normed| = {x_normed.abs().max():.4f}") + print(f"Q latent |c_Q| = {c_Q.abs().max():.4f}") + print(f"Q heads |q| = {q.abs().max():.4f}") + print(f"KV |kv| = {kv.abs().max():.4f}") + print(f"Attn out |attn_inv| = {attn_out_inv.abs().max():.4f}") + print(f"Grouped |grouped| = {grouped_flat.abs().max():.4f}") + print(f"F_attn (output) |F| = {F_attn.abs().max():.4f}") + print(f"Scale ratio F/x_norm = {F_attn.abs().max()/max(x_normed.abs().max(), 1e-8):.4f}") + + # Check: is F_attn reasonable? + if F_attn.abs().max() > 100: + print(f"\n⚠️ WARNING: F_attn is very large ({F_attn.abs().max():.1f}). " + f"This will cause residual growth in the full model.") + elif F_attn.abs().max() < 0.01: + print(f"\n⚠️ WARNING: F_attn is very small. " + f"Attention output is being suppressed.") + else: + print(f"\n✅ F_attn is on a reasonable scale.") if __name__ == "__main__":