fix verify_attention.py: proper nvfp4_linear calls

This commit is contained in:
2026-05-31 05:53:49 +00:00
parent c09f68c867
commit 2333fc8b4b

View File

@@ -16,7 +16,8 @@ LAYER_IDX = 0
# Correct E2M1 magnitudes
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
def dequant_nvfp4(weight, weight_scale, weight_scale_2):
def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
"""Dequantize NVFP4 weight to BF16."""
out_dim = weight.shape[0]
in_packed = weight.shape[1]
in_features = in_packed * 2
@@ -33,6 +34,18 @@ def dequant_nvfp4(weight, weight_scale, weight_scale_2):
return (w_f * scale_expanded).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
"""BF16 linear with NVFP4 dequant."""
w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
return torch.nn.functional.linear(x, w)
def rmsnorm(x, weight, eps=1e-6):
x_f = x.float()
rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
return (x_f * rms * weight.float()).bfloat16()
def main():
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
cfg = json.load(f)
@@ -58,62 +71,70 @@ def main():
for key, shard in wm.items():
if key.startswith(f"model.layers.{LAYER_IDX}.self_attn.") and "compressor" not in key and "indexer" not in key:
data = load_file(str(cdir / shard))
w[key.split(f"model.layers.{LAYER_IDX}.")[1]] = data[key].cuda()
w[key] = data[key].cuda()
print("Loaded weights:")
for k, v in sorted(w.items()):
print(f" {k}: {v.shape} {v.dtype}")
print("Loaded attention weights:")
for k in sorted(w.keys()):
if "self_attn" in k:
print(f" ...{k.split('self_attn.')[1]}: {w[k].shape} {w[k].dtype}")
# Create input: random hidden state after RMSNorm (unit scale)
# Create input: random hidden state
torch.manual_seed(42)
x = torch.randn(1, H, dtype=torch.bfloat16, device='cuda:0')
# RMSNorm
x_f = x.float()
rms = x_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
x_normed = (x_f * rms).bfloat16()
x_raw = torch.randn(1, H, dtype=torch.bfloat16, device='cuda:0')
print(f"\nInput |x_raw| = {x_raw.abs().max():.4f}")
# === Q projection ===
c_Q = dequant_nvfp4(x_normed, w["self_attn.q_a_proj.weight"],
w["self_attn.q_a_proj.weight_scale"],
w["self_attn.q_a_proj.weight_scale_2"])
print(f"\nc_Q: shape={c_Q.shape}, |c_Q|={c_Q.abs().max():.4f}, mean={c_Q.float().mean():.4f}")
# RMSNorm
input_layernorm_w = None
for k in w:
if 'input_layernorm' in k:
input_layernorm_w = w[k]
break
if input_layernorm_w is not None:
x_normed = rmsnorm(x_raw, input_layernorm_w)
else:
x_normed = x_raw
print(f"After RMSNorm |x_normed| = {x_normed.abs().max():.4f}")
# === Q projection: q_a → q_a_norm → q_b ===
c_Q = nvfp4_linear(x_normed,
w[f"{pre}.q_a_proj.weight"],
w[f"{pre}.q_a_proj.weight_scale"],
w[f"{pre}.q_a_proj.weight_scale_2"])
print(f"\nc_Q: shape={c_Q.shape}, |c_Q|={c_Q.abs().max():.4f}")
# q_a_norm
q_norm_w = w["self_attn.q_a_norm.weight"]
c_Q_f = c_Q.float()
c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16()
print(f"After q_a_norm: |c_Q|={c_Q.abs().max():.4f}")
if f"{pre}.q_a_norm.weight" in w:
c_Q = rmsnorm(c_Q, w[f"{pre}.q_a_norm.weight"])
print(f"After q_a_norm: |c_Q|={c_Q.abs().max():.4f}")
q = dequant_nvfp4(c_Q, w["self_attn.q_b_proj.weight"],
w["self_attn.q_b_proj.weight_scale"],
w["self_attn.q_b_proj.weight_scale_2"])
print(f"q: shape={q.shape}, |q|={q.abs().max():.4f}, mean={q.float().mean():.4f}")
q = nvfp4_linear(c_Q,
w[f"{pre}.q_b_proj.weight"],
w[f"{pre}.q_b_proj.weight_scale"],
w[f"{pre}.q_b_proj.weight_scale_2"])
print(f"q: shape={q.shape}, |q|={q.abs().max():.4f}")
q_heads = q.reshape(1, n_h, hd)
print(f"q_heads: shape={q_heads.shape}, per-head norm={q_heads[0, 0].float().norm():.4f}")
print(f"q_heads: shape={q_heads.shape}")
# === KV projection ===
kv = dequant_nvfp4(x_normed, w["self_attn.kv_proj.weight"],
w["self_attn.kv_proj.weight_scale"],
w["self_attn.kv_proj.weight_scale_2"])
print(f"\nkv: shape={kv.shape}, |kv|={kv.abs().max():.4f}, mean={kv.float().mean():.4f}")
# === KV projection + kv_norm ===
kv = nvfp4_linear(x_normed,
w[f"{pre}.kv_proj.weight"],
w[f"{pre}.kv_proj.weight_scale"],
w[f"{pre}.kv_proj.weight_scale_2"])
print(f"\nkv: shape={kv.shape}, |kv|={kv.abs().max():.4f}")
# kv_norm
kv_norm_w = w["self_attn.kv_norm.weight"]
kv_f = kv.float()
kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16()
print(f"After kv_norm: |kv|={kv.abs().max():.4f}")
if f"{pre}.kv_norm.weight" in w:
kv = rmsnorm(kv, w[f"{pre}.kv_norm.weight"])
print(f"After kv_norm: |kv|={kv.abs().max():.4f}")
kv_heads = kv.reshape(1, 1, hd) # 1 KV head
print(f"kv_heads: shape={kv_heads.shape}")
# === Apply RoPE ===
half = rd // 2
freqs = 1.0 / (10000.0 ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd))
pos = torch.tensor([0], dtype=torch.long)
cos = torch.cos(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16()
freqs = 1.0 / (10000.0 ** (torch.arange(0, rd, 2, dtype=torch.float32, device='cuda') / rd))
pos = torch.tensor([0], dtype=torch.long, device='cuda')
cos = torch.cos(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16() # (1, half)
sin = torch.sin(pos.float().unsqueeze(1) * freqs.unsqueeze(0)).bfloat16()
def apply_rope(x, cos, sin):
@@ -124,18 +145,6 @@ def main():
out[:, :, nope:][..., 1::2] = x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos
return out
q_roped = apply_rope(q_heads, cos.unsqueeze(0), sin.unsqueeze(0))
kv_roped = apply_rope(kv_heads, cos.unsqueeze(0), sin.unsqueeze(0))
print(f"\nAfter RoPE: |q|={q_roped.abs().max():.4f}, |kv|={kv_roped.abs().max():.4f}")
# === Attention (single KV entry → output = V) ===
# For 1 KV entry, attention output = V (softmax of scalar = 1)
# With K=V (both RoPE'd), output = V_roped
# Then inverse RoPE should give back kv (pre-RoPE)
attn_out = kv_roped # (1, 1, hd) — just V
# Inverse RoPE
def apply_inverse_rope(o, cos, sin):
nope = hd - rd
out = o.clone()
@@ -144,11 +153,20 @@ def main():
out[:, :, nope:][..., 1::2] = -o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos
return out
q_roped = apply_rope(q_heads, cos.unsqueeze(0), sin.unsqueeze(0))
kv_roped = apply_rope(kv_heads, cos.unsqueeze(0), sin.unsqueeze(0))
# === Attention (single KV entry → output = V) ===
# For 1 KV entry, attention output = V (softmax of scalar = 1)
# With K=V (both RoPE'd), output = V_roped
# Then inverse RoPE should recover kv (pre-RoPE)
attn_out = kv_roped # (1, 1, hd) — just V
attn_out_inv = apply_inverse_rope(attn_out, cos.unsqueeze(0), sin.unsqueeze(0))
# Check: inverse RoPE should recover the original kv (for single position)
# Check inverse RoPE recovery
diff = (attn_out_inv[0, 0].float() - kv_heads[0, 0].float()).abs().max()
print(f"Inverse RoPE recovery: max diff = {diff:.6f} (should be ~0)")
print(f"\nInverse RoPE recovery: max diff = {diff:.6f} (should be ~0)")
# === Output projection ===
attn_flat = attn_out_inv.reshape(1, n_h * hd) # (1, 65536)
@@ -156,7 +174,7 @@ def main():
# wo_a: grouped linear
attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd) # (1, 16, 4096)
oa_w = w["self_attn.o_a_proj.weight"].bfloat16() # (16384, 4096)
oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16() # (16384, 4096)
oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim) # (16, 1024, 4096)
attn_for_bmm = attn_grouped.permute(1, 0, 2) # (16, 1, 4096)
grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) # (16, 1, 1024)
@@ -165,21 +183,34 @@ def main():
print(f"grouped_flat: shape={grouped_flat.shape}, |grouped_flat|={grouped_flat.abs().max():.4f}")
# wo_b
F_attn = dequant_nvfp4(grouped_flat, w["self_attn.o_b_proj.weight"],
w["self_attn.o_b_proj.weight_scale"],
w["self_attn.o_b_proj.weight_scale_2"])
F_attn = nvfp4_linear(grouped_flat,
w[f"{pre}.o_b_proj.weight"],
w[f"{pre}.o_b_proj.weight_scale"],
w[f"{pre}.o_b_proj.weight_scale_2"])
print(f"F_attn: shape={F_attn.shape}, |F_attn|={F_attn.abs().max():.4f}")
# Sanity: check that the output is on a reasonable scale
print(f"\n=== SUMMARY ===")
print(f"Input |x| = {x.abs().max():.4f}")
print(f"After norm |x_normed| = {x_normed.abs().max():.4f}")
print(f"Q latent |c_Q| = {c_Q.abs().max():.4f}")
print(f"Q heads |q| = {q.abs().max():.4f}")
print(f"KV |kv| = {kv.abs().max():.4f}")
print(f"Attn output (pre-proj) |attn| = {attn_out_inv.abs().max():.4f}")
print(f"F_attn (post-proj) |F| = {F_attn.abs().max():.4f}")
print(f"Scale ratio F_attn/x = {F_attn.abs().max()/x.abs().max():.4f}")
# === Sanity checks ===
print(f"\n{'='*50}")
print(f"ATTENTION SUB-BLOCK SUMMARY (Layer {LAYER_IDX})")
print(f"{'='*50}")
print(f"Input |x_normed| = {x_normed.abs().max():.4f}")
print(f"Q latent |c_Q| = {c_Q.abs().max():.4f}")
print(f"Q heads |q| = {q.abs().max():.4f}")
print(f"KV |kv| = {kv.abs().max():.4f}")
print(f"Attn out |attn_inv| = {attn_out_inv.abs().max():.4f}")
print(f"Grouped |grouped| = {grouped_flat.abs().max():.4f}")
print(f"F_attn (output) |F| = {F_attn.abs().max():.4f}")
print(f"Scale ratio F/x_norm = {F_attn.abs().max()/max(x_normed.abs().max(), 1e-8):.4f}")
# Check: is F_attn reasonable?
if F_attn.abs().max() > 100:
print(f"\n⚠️ WARNING: F_attn is very large ({F_attn.abs().max():.1f}). "
f"This will cause residual growth in the full model.")
elif F_attn.abs().max() < 0.01:
print(f"\n⚠️ WARNING: F_attn is very small. "
f"Attention output is being suppressed.")
else:
print(f"\n✅ F_attn is on a reasonable scale.")
if __name__ == "__main__":