D1.3: Fix CuTeDSL scoping - define tOrP0 unconditionally with p0 offset

This commit is contained in:
2026-05-23 21:01:18 +00:00
parent 0e81fc18aa
commit 47eade4afc

View File

@@ -174,16 +174,12 @@ class FmhaKernel:
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
# tOrP0: apply TMEM column offset for P0 (TMEM-P path only)
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
# The softmax warps store P at tmem_p0_offset columns. PV MMA must read
# from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM).
if not self.use_smem_p:
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout,
)
else:
tOrP0 = tOrP
# Must be defined unconditionally (CuTeDSL scoping: no if/else for variables).
p0_col_offset = self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0)
tOrP0 = cute.make_tensor(tOrP.iterator + p0_col_offset, tOrP.layout)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)