From 75fec90eef34812d06c553df18d397cf30a6f4fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 05:08:10 +0000 Subject: [PATCH] Fix CuTeDSL scoping: unconditionally define tOrP0 and tCrP --- dsv4/kernels/attention/fmha.py | 35 +++++++++++----------------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 9a66f8b8..87ff4ce6 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -251,30 +251,17 @@ class FmhaKernel: tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - # ── TMEM-P path: PV A-operand from TMEM ── - # Define both paths' variables before the branch (CuTeDSL requires variables - # to exist before dynamic control flow, even when the branch is compile-time) - if not use_smem_p: - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP) - tOrP = tOrP_base[(None, None, None, 0)] - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, - tOrP.layout, - ) - else: - # SMEM-P: PV reads from SMEM, but define tOrP0 as unused dummy - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP) - tOrP = tOrP_base[(None, None, None, 0)] - tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout) - - # ── SMEM-P path: PV A-operand from SMEM ── - if use_smem_p: - tCrP = pv_mma.make_fragment_A(sP) - else: - # TMEM-P: PV reads from TMEM, but define tCrP as unused dummy - tCrP = pv_mma.make_fragment_A(sP) + # ── PV A-operand: always define both tOrP0 (TMEM) and tCrP (SMEM) ── + # CuTeDSL can't propagate variables across if/else regions, so we + # unconditionally compute both and the unused one is dead-code-eliminated. + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout, + ) + tCrP = pv_mma.make_fragment_A(sP) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))