D1.1: Fix make_fragment_A — use sP for SMEM source pv_mma

This commit is contained in:
2026-05-23 06:04:44 +00:00
parent 80434d0284
commit df43c3232d

View File

@@ -166,20 +166,19 @@ class FmhaKernel:
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# PV A-operand: constructed based on use_smem_p (compile-time branch)
if not self.use_smem_p:
# TMEM-P: PV reads P from TMEM alias of QK C-fragment
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCrP = pv_mma.make_fragment_A(sP) # dummy, never used
else:
# SMEM-P: PV reads P from SMEM
tOrP0 = cute.make_tensor(tStS.iterator, tStS.layout) # dummy, never used
tCrP = pv_mma.make_fragment_A(sP)
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally
# When pv_mma uses TMEM source, make_fragment_A needs a TMEM-based tensor (tP from tStS).
# When pv_mma uses SMEM source, make_fragment_A needs an SMEM-based tensor (sP).
# We construct both paths using the appropriate tensor for make_fragment_A.
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
# For TMEM source PV: fragment_A from TMEM tensor tP
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0),
tOrP.layout)
# For SMEM source PV: fragment_A from SMEM tensor sP
tCrP = pv_mma.make_fragment_A(sP)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)