From 7477253eabb4bbb8c548ced527c33d40f93fd048 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 00:16:22 +0000 Subject: [PATCH] D1.3: Fix LSE tensor layout for weakly congruent store --- dsv4/kernels/attention/fmha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 392ab3a3..2fd7a053 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -117,7 +117,7 @@ class FmhaKernel: # CuTeDSL doesn't support None parameters in @cute.kernel. # For normalize=True, mLSE is unused (dead-code-eliminated by compiler). if const_expr(lse is None): - lse = cute.make_tensor(c.iterator, cute.make_layout((1, 1, 1), stride=(1, 1, 1))) + lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel @@ -469,7 +469,7 @@ class FmhaKernel: if sfw_idx == 0: _ln2 = Float32(0.6931471805599453) # ln(2) lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[0, 0, 0] = lse_val + mLSE[0] = lse_val tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)