D1: Use a plain Python `if` to select the sP layout (evaluated at trace time, not emitted as MLIR)
This commit is contained in:
@@ -154,12 +154,13 @@ class FmhaKernel:
|
||||
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
|
||||
# sP only needed for SMEM-P path. Save SMEM by allocating tiny buffer for TMEM-P.
|
||||
if self.use_smem_p:
|
||||
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner)
|
||||
else:
|
||||
# Tiny placeholder that matches the 4-mode slice pattern
|
||||
_tiny_p_layout = cute.make_layout(((1,1),1,(1,1),1))
|
||||
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_tiny_p_layout,byte_alignment=16)
|
||||
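# Must use const_expr for the conditional (CuTeDSL traces both branches).
# NOTE(review): the `if` below is a plain Python conditional, not wrapped in const_expr — confirm it is resolved at trace time as intended.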
|
||||
_p_layout = p_smem_s.outer
|
||||
_p_swizzle = p_smem_s.inner
|
||||
if not self.use_smem_p:
|
||||
_p_layout = cute.make_layout(((1,1),1,(1,1),1))
|
||||
_p_swizzle = cute.make_layout(((1,1),1,(1,1),1))
|
||||
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
|
||||
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
|
||||
Reference in New Issue
Block a user