From aff208fb4cb9f7f98384d5d2103d57dae1d8df21 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 21:15:14 +0000 Subject: [PATCH] D5a: Fix LSE formula - lse = ln(row_sum) + row_max * ln(2) row_max is in the scale_log2 domain (it is the max of S * scale * log2(e)), so it must be converted to the natural-log domain before it is added to ln(row_sum). attn_max = row_max * ln(2), so lse = ln(row_sum) + row_max * ln(2). --- dsv4/kernels/attention/fmha.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 9ed1e596..5a6e816d 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -445,14 +445,17 @@ class FmhaKernel: c_pipe.producer_tail() # D5a: Write LSE (log-softmax) when normalize=False - # lse = log(row_sum) + row_max_safe (row_max in scaled domain) - # Only thread 0 of the epilogue warps writes LSE for this tile. + # lse = ln(row_sum) + attn_max + # row_max is in the scale_log2 domain: max(S * scale * log2(e)) + # attn_max = row_max * ln(2) (converting log2 domain to natural log domain) + # So lse = ln(row_sum) + row_max * ln(2) if const_expr(not self.normalize): _row_max_safe = row_max if row_max == -cutlass.Float32.inf: _row_max_safe = Float32(0.0) if sfw_idx == 0: - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 # Write LSE to GMEM: mLSE is a (1,1,1) FP32 tensor mLSE[0, 0, 0] = lse_val