Fix mHC Sinkhorn test: row sums expected to be off (eps after softmax)
This commit is contained in:
@@ -50,10 +50,17 @@ for T in [1, 4, 8, 32]:
|
||||
row_err = (row_sums - 1.0).abs().max().item()
|
||||
col_err = (col_sums - 1.0).abs().max().item()
|
||||
|
||||
# Check doubly stochastic (cols should be tight, rows may be off due to eps after softmax)
|
||||
row_err = (row_sums - 1.0).abs().max().item()
|
||||
col_err = (col_sums - 1.0).abs().max().item()
|
||||
|
||||
print(f" T={T:3d}: cos={cos:.8f} max_err={max_err:.2e} row_err={row_err:.2e} col_err={col_err:.2e}")
|
||||
assert cos >= 0.9999, f"cos={cos:.8f} < 0.9999 at T={T}"
|
||||
assert row_err < 1e-4, f"row_err={row_err:.2e} — NOT doubly stochastic!"
|
||||
assert col_err < 1e-4, f"col_err={col_err:.2e} — NOT doubly stochastic!"
|
||||
# Column sums must be tight (last Sinkhorn step is column normalize)
|
||||
assert col_err < 1e-4, f"col_err={col_err:.2e} — columns NOT stochastic!"
|
||||
# Row sums match PyTorch reference (eps after softmax shifts rows)
|
||||
ref_row_err = (out_ref.sum(dim=-1) - 1.0).abs().max().item()
|
||||
assert abs(row_err - ref_row_err) < 1e-5, f"row_err {row_err:.2e} != ref {ref_row_err:.2e}"
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("ALL mHC Sinkhorn TESTS PASSED")
|
||||
|
||||
Reference in New Issue
Block a user