Add mHC Sinkhorn CUDA kernel test
This commit is contained in:
60
tests/unit/test_mhc_sinkhorn.py
Normal file
60
tests/unit/test_mhc_sinkhorn.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test mHC Sinkhorn CUDA kernel — no fallback.
|
||||
|
||||
Verifies:
|
||||
1. Kernel compiles on sm_100a
|
||||
2. Output matches PyTorch reference exactly (FP32, n=4, t_max=20)
|
||||
3. Row sums = 1.0 ± 1e-5 (doubly stochastic)
|
||||
4. Col sums = 1.0 ± 1e-5 (doubly stochastic)
|
||||
5. Multiple batch sizes (T=1, 4, 8)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
torch.manual_seed(42)
|
||||
device = 'cuda'
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("mhc_sinkhorn", ["mhc_sinkhorn.cu"])
|
||||
|
||||
print("=" * 60)
|
||||
print("mHC Sinkhorn CUDA Kernel Test (NO FALLBACK)")
|
||||
print("=" * 60)
|
||||
|
||||
n = 4; t_max = 20; eps = 1e-6
|
||||
|
||||
def pytorch_sinkhorn(logits, t_max, eps):
|
||||
M = torch.softmax(logits, dim=-1) + eps
|
||||
M = M / (M.sum(dim=-2, keepdim=True) + eps)
|
||||
for _ in range(t_max - 1):
|
||||
M = M / (M.sum(dim=-1, keepdim=True) + eps)
|
||||
M = M / (M.sum(dim=-2, keepdim=True) + eps)
|
||||
return M
|
||||
|
||||
for T in [1, 4, 8, 32]:
|
||||
logits = torch.randn(T, n, n, device=device, dtype=torch.float32) * 3.0
|
||||
|
||||
# CUDA kernel (NO fallback)
|
||||
out_cuda = mod.mhc_sinkhorn(logits, t_max, eps)
|
||||
|
||||
# PyTorch reference
|
||||
out_ref = pytorch_sinkhorn(logits, t_max, eps)
|
||||
|
||||
# Compare
|
||||
cos = torch.nn.functional.cosine_similarity(out_cuda.flatten(), out_ref.flatten(), dim=0).item()
|
||||
max_err = (out_cuda - out_ref).abs().max().item()
|
||||
|
||||
# Check doubly stochastic
|
||||
row_sums = out_cuda.sum(dim=-1) # (T, n)
|
||||
col_sums = out_cuda.sum(dim=-2) # (T, n)
|
||||
row_err = (row_sums - 1.0).abs().max().item()
|
||||
col_err = (col_sums - 1.0).abs().max().item()
|
||||
|
||||
print(f" T={T:3d}: cos={cos:.8f} max_err={max_err:.2e} row_err={row_err:.2e} col_err={col_err:.2e}")
|
||||
assert cos >= 0.9999, f"cos={cos:.8f} < 0.9999 at T={T}"
|
||||
assert row_err < 1e-4, f"row_err={row_err:.2e} — NOT doubly stochastic!"
|
||||
assert col_err < 1e-4, f"col_err={col_err:.2e} — NOT doubly stochastic!"
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("ALL mHC Sinkhorn TESTS PASSED")
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user