Add mHC Sinkhorn CUDA kernel test

This commit is contained in:
2026-06-02 10:45:02 +00:00
parent 6cb5078821
commit b5f29be169

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
"""Test mHC Sinkhorn CUDA kernel — no fallback.
Verifies:
1. Kernel compiles on sm_100a
2. Output matches PyTorch reference exactly (FP32, n=4, t_max=20)
3. Row sums = 1.0 ± 1e-5 (doubly stochastic)
4. Col sums = 1.0 ± 1e-5 (doubly stochastic)
5. Multiple batch sizes (T=1, 4, 8)
"""
import torch
torch.manual_seed(42)
device = 'cuda'
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("mhc_sinkhorn", ["mhc_sinkhorn.cu"])
print("=" * 60)
print("mHC Sinkhorn CUDA Kernel Test (NO FALLBACK)")
print("=" * 60)
n = 4; t_max = 20; eps = 1e-6
def pytorch_sinkhorn(logits, t_max, eps):
M = torch.softmax(logits, dim=-1) + eps
M = M / (M.sum(dim=-2, keepdim=True) + eps)
for _ in range(t_max - 1):
M = M / (M.sum(dim=-1, keepdim=True) + eps)
M = M / (M.sum(dim=-2, keepdim=True) + eps)
return M
for T in [1, 4, 8, 32]:
logits = torch.randn(T, n, n, device=device, dtype=torch.float32) * 3.0
# CUDA kernel (NO fallback)
out_cuda = mod.mhc_sinkhorn(logits, t_max, eps)
# PyTorch reference
out_ref = pytorch_sinkhorn(logits, t_max, eps)
# Compare
cos = torch.nn.functional.cosine_similarity(out_cuda.flatten(), out_ref.flatten(), dim=0).item()
max_err = (out_cuda - out_ref).abs().max().item()
# Check doubly stochastic
row_sums = out_cuda.sum(dim=-1) # (T, n)
col_sums = out_cuda.sum(dim=-2) # (T, n)
row_err = (row_sums - 1.0).abs().max().item()
col_err = (col_sums - 1.0).abs().max().item()
print(f" T={T:3d}: cos={cos:.8f} max_err={max_err:.2e} row_err={row_err:.2e} col_err={col_err:.2e}")
assert cos >= 0.9999, f"cos={cos:.8f} < 0.9999 at T={T}"
assert row_err < 1e-4, f"row_err={row_err:.2e} — NOT doubly stochastic!"
assert col_err < 1e-4, f"col_err={col_err:.2e} — NOT doubly stochastic!"
print("\n" + "=" * 60)
print("ALL mHC Sinkhorn TESTS PASSED")
print("=" * 60)