From 54f1d0d669a803040bcedb2c4dcbffae00aa2690 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 22:30:14 +0000 Subject: [PATCH] D1.3: Fix LSE with const_expr, always create valid mLSE tensor --- dsv4/kernels/attention/fmha.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index bf13d2b3..a2ddef30 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -113,6 +113,9 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + # Always create a valid mLSE tensor (even for normalize=True, it's just unused) + if lse is None: + lse = cute.make_tensor(c.iterator, cute.make_layout((1, 1, 1), stride=(1, 1, 1))) self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel @@ -468,19 +471,16 @@ class FmhaKernel: ) c_pipe.producer_tail() - # D5a: Write LSE (log-softmax) — always when mLSE is provided + # D5a: Write LSE (log-softmax) when normalize=False # lse = ln(row_sum) + row_max * ln(2) - # This is needed for the SWA+sink merge formula: - # numerator = exp(lse1) * O1_norm + exp(sink) * exp(lse2) * O2_norm - # denominator = exp(lse1) + exp(sink) * exp(lse2) - if mLSE is not None: + # row_max is in scale_log2 domain, multiply by ln(2) to convert. + if const_expr(not self.normalize): _row_max_safe = row_max if row_max == -cutlass.Float32.inf: _row_max_safe = Float32(0.0) if sfw_idx == 0: _ln2 = Float32(0.6931471805599453) # ln(2) lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - # Write LSE to GMEM: mLSE is a (1,1,1) FP32 tensor mLSE[0, 0, 0] = lse_val tmem.relinquish_alloc_permit()