test: add try/except for SMEM-P coord test

This commit is contained in:
2026-05-24 02:15:07 +00:00
parent 16bade9e10
commit 8010e3dda2

View File

@@ -44,8 +44,21 @@ def test_smem_p_coords():
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
print("Compiling FmhaKernel (hd=256, SMEM-P, normalize=False)...")
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
compiled(mQ, mK, mV, mC, stream, mLSE)
try:
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
except Exception as e:
print(f"COMPILE FAILED: {e}")
import traceback
traceback.print_exc()
return
print("Running...")
try:
compiled(mQ, mK, mV, mC, stream, mLSE)
except Exception as e:
print(f"RUN FAILED: {e}")
import traceback
traceback.print_exc()
return
torch.cuda.synchronize()
# The kernel writes P to sP using the coordinate-indexed approach