test: add per-head magnitude ratio diagnostics to decode FMHA test

This commit is contained in:
2026-06-03 04:50:23 +00:00
parent f5fa20c581
commit 9a9b347b2b

View File

@@ -452,13 +452,18 @@ def main():
mag_prod = o_prod.float().abs().max().item()
mag_ref = o_ref.float().abs().max().item()
# Per-head cosine
# Per-head cosine AND magnitude ratio
o_prod_h = o_prod.float().squeeze(1) # (n_h, hd)
o_ref_h = o_ref.float().squeeze(1)
per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1)
per_head_mag_prod = o_prod_h.abs().max(dim=-1).values # (n_h,)
per_head_mag_ref = o_ref_h.abs().max(dim=-1).values # (n_h,)
per_head_mag_ratio = (per_head_mag_prod / (per_head_mag_ref + 1e-8)) # (n_h,)
min_head = per_head_cos.min().item()
mean_head = per_head_cos.mean().item()
worst_heads = per_head_cos.argsort()[:5]
# Find the (up to 5) heads whose prod/ref magnitude ratio deviates most from 1.0
worst_mag = per_head_mag_ratio.sub(1.0).abs().argsort(descending=True)[:5]
results[li] = {
'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref,
@@ -471,7 +476,11 @@ def main():
print(f" L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} "
f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq={seq_len} {gather_mode}", flush=True)
if cos_val < 0.999:
print(f" Worst heads: {worst_heads.tolist()} cos={[f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()]}")
cos_list = [f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()]
mag_list = [f'{r:.4f}' for r in per_head_mag_ratio[worst_mag].tolist()]
print(f" Worst heads (cos): {worst_heads.tolist()} cos={cos_list}")
print(f" Worst heads (mag): {worst_mag.tolist()} ratio={mag_list}")
print(f" Mag ratio range: [{per_head_mag_ratio.min().item():.4f}, {per_head_mag_ratio.max().item():.4f}]")
# ---- Continue through the rest of the layer (so subsequent layers get correct X) ----
# Apply inverse RoPE to production output