SMEM-P: add tCrP debug print, reset permute to 0

This commit is contained in:
2026-05-23 20:14:32 +00:00
parent 3bbc9a5a86
commit 6629fe57cf

View File

@@ -24,7 +24,7 @@ class FmhaKernel:
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping
self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping
self.debug_permute = 1 # DEBUG: try different coordinate permutations (0=original,1=swap m↔n0)
self.debug_permute = 0 # DEBUG: try different coordinate permutations (0=original,1=swap m↔n0)
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
@@ -184,6 +184,8 @@ class FmhaKernel:
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
if self.use_smem_p:
print(f"[SMEM-P DEBUG] tCrP shape: {cute.shape(tCrP)} layout: {tCrP.layout}")
# tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies
# the p0 column offset inline when constructing the gemm arguments.
tOrP0 = tOrP