From bfacef28d328b7c968004fd04c4d1ac086190d68 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 22:30:55 +0000 Subject: [PATCH] D1.3: Use const_expr for lse None check --- dsv4/kernels/attention/fmha.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a2ddef30..b71eb04b 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -113,8 +113,10 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - # Always create a valid mLSE tensor (even for normalize=True, it's just unused) - if lse is None: + # Always create a valid mLSE tensor for the kernel. + # CuTeDSL doesn't support None parameters in @cute.kernel. + # For normalize=True, mLSE is unused (dead-code-eliminated by compiler). + if const_expr(lse is None): lse = cute.make_tensor(c.iterator, cute.make_layout((1, 1, 1), stride=(1, 1, 1))) self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)