From 03ad730a9b32ef2e71f833d3fd13249eb026963a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 04:02:02 +0000 Subject: [PATCH] D1: Conditional sP allocation (saves 64KB SMEM for TMEM-P at hd=256) --- dsv4/kernels/attention/fmha.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 3b99fda3..774b6882 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -153,7 +153,10 @@ class FmhaKernel: sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) if self.use_smem_p else smem.allocate_tensor(element_type=self.q_dtype,layout=cute.make_layout((1,1)),byte_alignment=16) + # sP only needed for SMEM-P path. Save SMEM by allocating tiny buffer for TMEM-P. + _p_alloc_layout = p_smem_s.outer if self.use_smem_p else cute.make_layout((1,)) + _p_alloc_swizzle = p_smem_s.inner if self.use_smem_p else cute.make_layout((1,)) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_alloc_layout,byte_alignment=128,swizzle=_p_alloc_swizzle) gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) @@ -186,7 +189,8 @@ class FmhaKernel: tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) tOrP = tOrP_base[(None,None,None,0)] - tCrP = pv_mma.make_fragment_A(sP) + # tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping. + tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP) # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). # self.tOrP0_offset is pre-computed in _setup as a Python int. # Use const_expr if/else for compile-time conditional.