diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 4cb3682c..dedaa0d1 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -166,14 +166,20 @@ class FmhaKernel: tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - # PV A-operand: always define both TMEM and SMEM paths (CuTeDSL scoping) - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP) - tOrP = tOrP_base[(None,None,None,0)] - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0), - tOrP.layout) - tCrP = pv_mma.make_fragment_A(sP) + # PV A-operand: constructed based on use_smem_p (compile-time branch) + if not self.use_smem_p: + # TMEM-P: PV reads P from TMEM alias of QK C-fragment + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None,None,None,0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + tCrP = pv_mma.make_fragment_A(sP) # dummy, never used + else: + # SMEM-P: PV reads P from SMEM + tOrP0 = cute.make_tensor(tStS.iterator, tStS.layout) # dummy, never used + tCrP = pv_mma.make_fragment_A(sP) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)