From 4418e04a28fb37a8071deecc4fba969334e41b1a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 22:28:11 +0000 Subject: [PATCH] D1: revert per-row LSE to sfw_idx=0 for now (debugging D2 regression) --- dsv4/kernels/attention/fmha.py | 4 +++- tests/unit/test_d2_multihead.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 1aaa5b51..44178667 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -507,13 +507,15 @@ class FmhaKernel: # When normalize=True, LSE is not needed (in-kernel normalization). # Each thread writes its row's LSE. With 128 softmax threads and 128 rows, # each thread (sfw_idx) owns exactly one row. + # mLSE shape is (T, 1, 1); mLSE[i, 0, 0] is row i's LSE slot. TEMPORARY (debugging D2 regression): only sfw_idx == 0 writes, to mLSE[0]. if const_expr(not self.normalize): _row_max_safe = row_max if row_max == -cutlass.Float32.inf: _row_max_safe = Float32(0.0) _ln2 = Float32(0.6931471805599453) # ln(2) lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[sfw_idx] = lse_val + if sfw_idx == 0: + mLSE[0] = lse_val tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) diff --git a/tests/unit/test_d2_multihead.py b/tests/unit/test_d2_multihead.py index fecea619..b0096b97 100644 --- a/tests/unit/test_d2_multihead.py +++ b/tests/unit/test_d2_multihead.py @@ -66,7 +66,7 @@ def test_multihead(hd=64, n_h=1, batch=1, T=128, s_k=128): v_tile = v_kernel[:, 0:pv_n_tile].contiguous() v_k = v_tile.unsqueeze(-1) c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') - lse_t = torch.zeros(T, dtype=torch.float32, device='cuda') + lse_t = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda') mQ = ct.from_dlpack(q_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_kernel)) mK = ct.from_dlpack(k_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_kernel))