From 0ca7b58a6ac6b7ecf37cedf00c716e18b75bc17a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 22:41:32 +0000 Subject: [PATCH] D1: fully revert LSE change back to original sfw_idx==0 guard --- dsv4/kernels/attention/fmha.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 44178667..f42693d4 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -505,16 +505,13 @@ class FmhaKernel: # Compute LSE: lse = ln(row_sum) + row_max * ln(2) # Only when emitting un-normalized output (D5a path). # When normalize=True, LSE is not needed (in-kernel normalization). - # Each thread writes its row's LSE. With 128 softmax threads and 128 rows, - # each thread (sfw_idx) owns exactly one row. - # mLSE shape is (T, 1, 1). mLSE[i, 0, 0] writes row i's LSE. if const_expr(not self.normalize): _row_max_safe = row_max if row_max == -cutlass.Float32.inf: _row_max_safe = Float32(0.0) - _ln2 = Float32(0.6931471805599453) # ln(2) - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 if sfw_idx == 0: + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 mLSE[0] = lse_val tmem.relinquish_alloc_permit()