test: add try/except for SMEM-P coord test
This commit is contained in:
@@ -44,8 +44,21 @@ def test_smem_p_coords():
|
||||
mLSE = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse))
|
||||
|
||||
print("Compiling FmhaKernel (hd=256, SMEM-P, normalize=False)...")
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
|
||||
compiled(mQ, mK, mV, mC, stream, mLSE)
|
||||
try:
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
|
||||
except Exception as e:
|
||||
print(f"COMPILE FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
print("Running...")
|
||||
try:
|
||||
compiled(mQ, mK, mV, mC, stream, mLSE)
|
||||
except Exception as e:
|
||||
print(f"RUN FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# The kernel writes P to sP using the coordinate-indexed approach
|
||||
|
||||
Reference in New Issue
Block a user