diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py
index 1aaa5b51..44178667 100644
--- a/dsv4/kernels/attention/fmha.py
+++ b/dsv4/kernels/attention/fmha.py
@@ -507,13 +507,14 @@ class FmhaKernel:
         # When normalize=True, LSE is not needed (in-kernel normalization).
         # Each thread writes its row's LSE. With 128 softmax threads and 128 rows,
         # each thread (sfw_idx) owns exactly one row.
+        # mLSE shape is (T, 1, 1); thread sfw_idx stores its row's LSE at mLSE[sfw_idx, 0, 0].
         if const_expr(not self.normalize):
             _row_max_safe = row_max
             if row_max == -cutlass.Float32.inf:
                 _row_max_safe = Float32(0.0)
             _ln2 = Float32(0.6931471805599453)  # ln(2)
             lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
-            mLSE[sfw_idx] = lse_val
+            mLSE[sfw_idx, 0, 0] = lse_val
 
         tmem.relinquish_alloc_permit()
         tmem.free(tmem_ptr)
diff --git a/tests/unit/test_d2_multihead.py b/tests/unit/test_d2_multihead.py
index fecea619..b0096b97 100644
--- a/tests/unit/test_d2_multihead.py
+++ b/tests/unit/test_d2_multihead.py
@@ -66,7 +66,7 @@ def test_multihead(hd=64, n_h=1, batch=1, T=128, s_k=128):
     v_tile = v_kernel[:, 0:pv_n_tile].contiguous()
     v_k = v_tile.unsqueeze(-1)
     c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
-    lse_t = torch.zeros(T, dtype=torch.float32, device='cuda')
+    lse_t = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
 
     mQ = ct.from_dlpack(q_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_kernel))
     mK = ct.from_dlpack(k_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_kernel))