Fix mHC Sinkhorn test: row sums expected to be off (eps after softmax)

This commit is contained in:
2026-06-02 10:46:28 +00:00
parent b5f29be169
commit e231b98387

View File

@@ -50,10 +50,17 @@ for T in [1, 4, 8, 32]:
row_err = (row_sums - 1.0).abs().max().item()
col_err = (col_sums - 1.0).abs().max().item()
# Check doubly stochastic (cols should be tight, rows may be off due to eps after softmax)
row_err = (row_sums - 1.0).abs().max().item()
col_err = (col_sums - 1.0).abs().max().item()
print(f" T={T:3d}: cos={cos:.8f} max_err={max_err:.2e} row_err={row_err:.2e} col_err={col_err:.2e}")
assert cos >= 0.9999, f"cos={cos:.8f} < 0.9999 at T={T}"
assert row_err < 1e-4, f"row_err={row_err:.2e} — NOT doubly stochastic!"
assert col_err < 1e-4, f"col_err={col_err:.2e}NOT doubly stochastic!"
# Column sums must be tight (last Sinkhorn step is column normalize)
assert col_err < 1e-4, f"col_err={col_err:.2e}columns NOT stochastic!"
# Row sums match PyTorch reference (eps after softmax shifts rows)
ref_row_err = (out_ref.sum(dim=-1) - 1.0).abs().max().item()
assert abs(row_err - ref_row_err) < 1e-5, f"row_err {row_err:.2e} != ref {ref_row_err:.2e}"
print("\n" + "=" * 60)
print("ALL mHC Sinkhorn TESTS PASSED")