From 47eade4afc4abb38153ddcc6506869338bb3190d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 21:01:18 +0000 Subject: [PATCH] D1.3: Fix CuTeDSL scoping - define tOrP0 unconditionally with p0 offset --- dsv4/kernels/attention/fmha.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a3476ce0..cba1b032 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -174,16 +174,12 @@ class FmhaKernel: tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) tOrP = tOrP_base[(None,None,None,0)] tCrP = pv_mma.make_fragment_A(sP) - # tOrP0: apply TMEM column offset for P0 (TMEM-P path only) + # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). # The softmax warps store P at tmem_p0_offset columns. PV MMA must read # from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM). - if not self.use_smem_p: - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, - tOrP.layout, - ) - else: - tOrP0 = tOrP + # Must be defined unconditionally (CuTeDSL scoping: no if/else for variables). + p0_col_offset = self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0) + tOrP0 = cute.make_tensor(tOrP.iterator + p0_col_offset, tOrP.layout) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)