fix: const_expr for SMEM-P tma_p creation

This commit is contained in:
2026-05-24 03:37:20 +00:00
parent 6f0475f0db
commit 49f54aef2d

View File

@@ -89,7 +89,7 @@ class FmhaKernel:
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, gP=None, tma_p=None):
def __call__(self, q, k, v, c, stream, lse=None, gP=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
@@ -115,16 +115,15 @@ class FmhaKernel:
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
# SMEM-P: gP buffer and TMA for P (GMEM→SMEM via TMA)
if self.use_smem_p and gP is not None:
# gP is passed by the caller when use_smem_p=True
if const_expr(self.use_smem_p):
p_s = cute.slice_(self.p_smem_s,(None,None,None,0))
tma_p,gP = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, pv_mma.thr_id),
gP, p_s, self.qk_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape
)
elif not self.use_smem_p:
tma_p = tma_q # dummy, dead code
else:
raise ValueError("use_smem_p=True but no gP provided")
tma_p = tma_q # dummy, dead code
# Always create a valid mLSE tensor for the kernel.
# CuTeDSL doesn't support None parameters in @cute.kernel.
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).