diff --git a/tests/unit/test_smem_p_coord.py b/tests/unit/test_smem_p_coord.py index a5c1320f..7f8fbc0d 100644 --- a/tests/unit/test_smem_p_coord.py +++ b/tests/unit/test_smem_p_coord.py @@ -44,8 +44,21 @@ def test_smem_p_coords(): mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse)) print("Compiling FmhaKernel (hd=256, SMEM-P, normalize=False)...") - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) - compiled(mQ, mK, mV, mC, stream, mLSE) + try: + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + except Exception as e: + print(f"COMPILE FAILED: {e}") + import traceback + traceback.print_exc() + return + print("Running...") + try: + compiled(mQ, mK, mV, mC, stream, mLSE) + except Exception as e: + print(f"RUN FAILED: {e}") + import traceback + traceback.print_exc() + return torch.cuda.synchronize() # The kernel writes P to sP using the coordinate-indexed approach