diff --git a/tests/unit/test_mhc_sinkhorn.py b/tests/unit/test_mhc_sinkhorn.py index b29c3480..566640c8 100644 --- a/tests/unit/test_mhc_sinkhorn.py +++ b/tests/unit/test_mhc_sinkhorn.py @@ -50,10 +50,17 @@ for T in [1, 4, 8, 32]: row_err = (row_sums - 1.0).abs().max().item() col_err = (col_sums - 1.0).abs().max().item() + # Check doubly stochastic (cols should be tight, rows may be off due to eps after softmax) + row_err = (row_sums - 1.0).abs().max().item() + col_err = (col_sums - 1.0).abs().max().item() + print(f" T={T:3d}: cos={cos:.8f} max_err={max_err:.2e} row_err={row_err:.2e} col_err={col_err:.2e}") assert cos >= 0.9999, f"cos={cos:.8f} < 0.9999 at T={T}" - assert row_err < 1e-4, f"row_err={row_err:.2e} — NOT doubly stochastic!" - assert col_err < 1e-4, f"col_err={col_err:.2e} — NOT doubly stochastic!" + # Column sums must be tight (last Sinkhorn step is column normalize) + assert col_err < 1e-4, f"col_err={col_err:.2e} — columns NOT stochastic!" + # Row sums match PyTorch reference (eps after softmax shifts rows) + ref_row_err = (out_ref.sum(dim=-1) - 1.0).abs().max().item() + assert abs(row_err - ref_row_err) < 1e-5, f"row_err {row_err:.2e} != ref {ref_row_err:.2e}" print("\n" + "=" * 60) print("ALL mHC Sinkhorn TESTS PASSED")