D1: fully revert LSE change back to original sfw_idx==0 guard

This commit is contained in:
2026-05-24 22:41:32 +00:00
parent db353ec35a
commit 0ca7b58a6a

View File

@@ -505,16 +505,13 @@ class FmhaKernel:
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
# Only when emitting un-normalized output (D5a path).
# When normalize=True, LSE is not needed (in-kernel normalization).
# Each thread writes its row's LSE. With 128 softmax threads and 128 rows,
# each thread (sfw_idx) owns exactly one row.
# mLSE shape is (T, 1, 1). mLSE[i, 0, 0] writes row i's LSE.
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
if sfw_idx == 0:
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[0] = lse_val
tmem.relinquish_alloc_permit()