diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 49522eb0..a3476ce0 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -174,9 +174,16 @@ class FmhaKernel: tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) tOrP = tOrP_base[(None,None,None,0)] tCrP = pv_mma.make_fragment_A(sP) - # tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies - # the p0 column offset inline when constructing the gemm arguments. - tOrP0 = tOrP + # tOrP0: apply TMEM column offset for P0 (TMEM-P path only) + # The softmax warps store P at tmem_p0_offset columns. PV MMA must read + # from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM). + if not self.use_smem_p: + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout, + ) + else: + tOrP0 = tOrP tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)