From a945edea79aceb8ba5d9b62fb6caa216a8e9a9fa Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 04:04:27 +0000 Subject: [PATCH] D1: Python if for sP layout (trace-time, not MLIR) --- dsv4/kernels/attention/fmha.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 0d112325..3793e69c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -154,12 +154,13 @@ class FmhaKernel: sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) # sP only needed for SMEM-P path. Save SMEM by allocating tiny buffer for TMEM-P. - if self.use_smem_p: - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) - else: - # Tiny placeholder that matches the 4-mode slice pattern - _tiny_p_layout = cute.make_layout(((1,1),1,(1,1),1)) - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_tiny_p_layout,byte_alignment=16) + # Must use const_expr for the conditional (CuTeDSL traces both branches). + _p_layout = p_smem_s.outer + _p_swizzle = p_smem_s.inner + if not self.use_smem_p: + _p_layout = cute.make_layout(((1,1),1,(1,1),1)) + _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))