From 49f54aef2d269f10ea194768790965096c6441ae Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:37:20 +0000 Subject: [PATCH] fix: const_expr for SMEM-P tma_p creation --- dsv4/kernels/attention/fmha.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index ef015e40..2995d6f2 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -89,7 +89,7 @@ class FmhaKernel: cute.size_in_bytes(self.q_dtype, v_s)) * cta @cute.jit - def __call__(self, q, k, v, c, stream, lse=None, gP=None, tma_p=None): + def __call__(self, q, k, v, c, stream, lse=None, gP=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() @@ -115,16 +115,15 @@ class FmhaKernel: tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) # SMEM-P: gP buffer and TMA for P (GMEM→SMEM via TMA) - if self.use_smem_p and gP is not None: + # gP is passed by the caller when use_smem_p=True + if const_expr(self.use_smem_p): p_s = cute.slice_(self.p_smem_s,(None,None,None,0)) tma_p,gP = cute.nvgpu.make_tiled_tma_atom_A( utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, pv_mma.thr_id), gP, p_s, self.qk_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape ) - elif not self.use_smem_p: - tma_p = tma_q # dummy, dead code else: - raise ValueError("use_smem_p=True but no gP provided") + tma_p = tma_q # dummy, dead code # Always create a valid mLSE tensor for the kernel. # CuTeDSL doesn't support None parameters in @cute.kernel. # For normalize=True, mLSE is unused (dead-code-eliminated by compiler).