fix: use const_expr for SMEM-P tma_p creation
This commit is contained in:
@@ -89,7 +89,7 @@ class FmhaKernel:
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None, gP=None, tma_p=None):
|
||||
def __call__(self, q, k, v, c, stream, lse=None, gP=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
@@ -115,16 +115,15 @@ class FmhaKernel:
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
|
||||
# SMEM-P: gP buffer and TMA for P (GMEM→SMEM via TMA)
|
||||
if self.use_smem_p and gP is not None:
|
||||
# gP is passed by the caller when use_smem_p=True
|
||||
if const_expr(self.use_smem_p):
|
||||
p_s = cute.slice_(self.p_smem_s,(None,None,None,0))
|
||||
tma_p,gP = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
gP, p_s, self.qk_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape
|
||||
)
|
||||
elif not self.use_smem_p:
|
||||
tma_p = tma_q # dummy, dead code
|
||||
else:
|
||||
raise ValueError("use_smem_p=True but no gP provided")
|
||||
tma_p = tma_q # dummy, dead code
|
||||
# Always create a valid mLSE tensor for the kernel.
|
||||
# CuTeDSL doesn't support None parameters in @cute.kernel.
|
||||
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).
|
||||
|
||||
Reference in New Issue
Block a user