SMEM-P: add shape debug prints

This commit is contained in:
2026-05-23 20:02:32 +00:00
parent d008505330
commit 8ca07722ae

View File

@@ -324,7 +324,10 @@ class FmhaKernel:
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout)
if self.use_smem_p:
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS shape: {cute.shape(tTMEM_LOADrS)}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS_frg shape: {cute.shape(tTMEM_LOADrS_frg)}")
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):