diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 44178667..f42693d4 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -505,16 +505,13 @@ class FmhaKernel: # Compute LSE: lse = ln(row_sum) + row_max * ln(2) # Only when emitting un-normalized output (D5a path). # When normalize=True, LSE is not needed (in-kernel normalization). - # Each thread writes its row's LSE. With 128 softmax threads and 128 rows, - # each thread (sfw_idx) owns exactly one row. - # mLSE shape is (T, 1, 1). mLSE[i, 0, 0] writes row i's LSE. if const_expr(not self.normalize): _row_max_safe = row_max if row_max == -cutlass.Float32.inf: _row_max_safe = Float32(0.0) - _ln2 = Float32(0.6931471805599453) # ln(2) - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 if sfw_idx == 0: + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 mLSE[0] = lse_val tmem.relinquish_alloc_permit()