D1: revert per-row LSE to sfw_idx=0 for now (debugging D2 regression)

This commit is contained in:
2026-05-24 22:28:11 +00:00
parent 2cc66bff68
commit 4418e04a28
2 changed files with 4 additions and 2 deletions

View File

@@ -507,13 +507,15 @@ class FmhaKernel:
# When normalize=True, LSE is not needed (in-kernel normalization).
# Each thread writes its row's LSE. With 128 softmax threads and 128 rows,
# each thread (sfw_idx) owns exactly one row.
# mLSE shape is (T, 1, 1). mLSE[i, 0, 0] writes row i's LSE.
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[sfw_idx] = lse_val
if sfw_idx == 0:
mLSE[0] = lse_val
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)

View File

@@ -66,7 +66,7 @@ def test_multihead(hd=64, n_h=1, batch=1, T=128, s_k=128):
v_tile = v_kernel[:, 0:pv_n_tile].contiguous()
v_k = v_tile.unsqueeze(-1)
c_tile = torch.zeros(T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_t = torch.zeros(T, dtype=torch.float32, device='cuda')
lse_t = torch.zeros(T, 1, 1, dtype=torch.float32, device='cuda')
mQ = ct.from_dlpack(q_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_kernel))
mK = ct.from_dlpack(k_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_kernel))