From cba41d500c9d804dec28f2ae2d90fb9144caa75d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 21:00:29 +0000 Subject: [PATCH] D1.3: Fix critical bug - add TMEM column offset for P0 in PV GEMM The softmax warps store P at TMEM column offset tmem_p0_offset (32), and the PV MMA must read P from that same offset. tOrP0 was built without the offset, so the PV GEMM read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of the NaN/zero outputs in the D1 tests. The fix applies the offset on the TMEM-P path only, scaled by qk_acc_dtype.width // q_dtype.width; on the SMEM-P path tOrP0 remains tOrP, which is bound to sP rather than TMEM. --- dsv4/kernels/attention/fmha.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 49522eb0..a3476ce0 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -174,9 +174,16 @@ class FmhaKernel: tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) tOrP = tOrP_base[(None,None,None,0)] tCrP = pv_mma.make_fragment_A(sP) - # tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies - # the p0 column offset inline when constructing the gemm arguments. - tOrP0 = tOrP + # tOrP0: apply TMEM column offset for P0 (TMEM-P path only) + # The softmax warps store P at tmem_p0_offset columns. PV MMA must read + # from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM). + if not self.use_smem_p: + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout, + ) + else: + tOrP0 = tOrP tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)