diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 3793e69c..f3d39635 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -153,11 +153,11 @@ class FmhaKernel: sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - # sP only needed for SMEM-P path. Save SMEM by allocating tiny buffer for TMEM-P. - # Must use const_expr for the conditional (CuTeDSL traces both branches). - _p_layout = p_smem_s.outer - _p_swizzle = p_smem_s.inner - if not self.use_smem_p: + # sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM) + if const_expr(self.use_smem_p): + _p_layout = p_smem_s.outer + _p_swizzle = p_smem_s.inner + else: _p_layout = cute.make_layout(((1,1),1,(1,1),1)) _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)