diff --git a/tests/unit/test_decode_fmha_layer.py b/tests/unit/test_decode_fmha_layer.py index 96e6dae7..99869977 100644 --- a/tests/unit/test_decode_fmha_layer.py +++ b/tests/unit/test_decode_fmha_layer.py @@ -452,13 +452,18 @@ def main(): mag_prod = o_prod.float().abs().max().item() mag_ref = o_ref.float().abs().max().item() - # Per-head cosine + # Per-head cosine AND magnitude ratio o_prod_h = o_prod.float().squeeze(1) # (n_h, hd) o_ref_h = o_ref.float().squeeze(1) per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1) + per_head_mag_prod = o_prod_h.abs().max(dim=-1).values # (n_h,) + per_head_mag_ref = o_ref_h.abs().max(dim=-1).values # (n_h,) + per_head_mag_ratio = (per_head_mag_prod / (per_head_mag_ref + 1e-8)) # (n_h,) min_head = per_head_cos.min().item() mean_head = per_head_cos.mean().item() worst_heads = per_head_cos.argsort()[:5] + # Find heads with worst magnitude ratio + worst_mag = per_head_mag_ratio.sub(1.0).abs().argsort(descending=True)[:5] results[li] = { 'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref, @@ -471,7 +476,11 @@ def main(): print(f" L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} " f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq={seq_len} {gather_mode}", flush=True) if cos_val < 0.999: - print(f" Worst heads: {worst_heads.tolist()} cos={[f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()]}") + cos_list = [f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()] + mag_list = [f'{r:.4f}' for r in per_head_mag_ratio[worst_mag].tolist()] + print(f" Worst heads (cos): {worst_heads.tolist()} cos={cos_list}") + print(f" Worst heads (mag): {worst_mag.tolist()} ratio={mag_list}") + print(f" Mag ratio range: [{per_head_mag_ratio.min().item():.4f}, {per_head_mag_ratio.max().item():.4f}]") # ---- Continue through the rest of the layer (so subsequent layers get correct X) ---- # Apply inverse RoPE to production output