diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 699f6d47..23142913 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -24,7 +24,7 @@ class FmhaKernel: self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping - self.debug_permute = 1 # DEBUG: try different coordinate permutations (0=original,1=swap m↔n0) + self.debug_permute = 0 # DEBUG: try different coordinate permutations (0=original,1=swap m↔n0) self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -184,6 +184,8 @@ class FmhaKernel: tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) tOrP = tOrP_base[(None,None,None,0)] tCrP = pv_mma.make_fragment_A(sP) + if self.use_smem_p: + print(f"[SMEM-P DEBUG] tCrP shape: {cute.shape(tCrP)} layout: {tCrP.layout}") # tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies # the p0 column offset inline when constructing the gemm arguments. tOrP0 = tOrP