From df43c3232dfcafb8ea894dcb86e30155c147597a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 06:04:44 +0000 Subject: [PATCH] =?UTF-8?q?D1.1:=20Fix=20make=5Ffragment=5FA=20=E2=80=94?= =?UTF-8?q?=20use=20sP=20for=20SMEM=20source=20pv=5Fmma?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index dedaa0d1..94aa9dca 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -166,20 +166,19 @@ class FmhaKernel: tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - # PV A-operand: constructed based on use_smem_p (compile-time branch) - if not self.use_smem_p: - # TMEM-P: PV reads P from TMEM alias of QK C-fragment - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP) - tOrP = tOrP_base[(None,None,None,0)] - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, - tOrP.layout) - tCrP = pv_mma.make_fragment_A(sP) # dummy, never used - else: - # SMEM-P: PV reads P from SMEM - tOrP0 = cute.make_tensor(tStS.iterator, tStS.layout) # dummy, never used - tCrP = pv_mma.make_fragment_A(sP) + # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally + # When pv_mma uses TMEM source, make_fragment_A needs a TMEM-based tensor (tP from tStS). + # When pv_mma uses SMEM source, make_fragment_A needs an SMEM-based tensor (sP). + # We construct both paths using the appropriate tensor for make_fragment_A. + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + # For TMEM source PV: fragment_A from TMEM tensor tP + tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) + tOrP = tOrP_base[(None,None,None,0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0), + tOrP.layout) + # For SMEM source PV: fragment_A from SMEM tensor sP + tCrP = pv_mma.make_fragment_A(sP) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)