D1: const_expr for sP layout selection (CuTeDSL)

This commit is contained in:
2026-05-24 04:05:17 +00:00
parent a945edea79
commit 38c6486fc7

View File

@@ -153,11 +153,11 @@ class FmhaKernel:
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
# sP is only needed for the SMEM-P path. Save SMEM by allocating a tiny buffer for TMEM-P.
# Must use const_expr for the conditional (CuTeDSL traces both branches).
_p_layout = p_smem_s.outer
_p_swizzle = p_smem_s.inner
if not self.use_smem_p:
# sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM)
if const_expr(self.use_smem_p):
_p_layout = p_smem_s.outer
_p_swizzle = p_smem_s.inner
else:
_p_layout = cute.make_layout(((1,1),1,(1,1),1))
_p_swizzle = cute.make_layout(((1,1),1,(1,1),1))
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)