D5a: Fix LSE formula - lse = ln(row_sum) + row_max * ln(2)

row_max is in the scale_log2 domain, i.e. max(S * scale * log2(e)), so it must be converted to the natural-log domain before it is added to ln(row_sum).
attn_max = row_max * ln(2), so lse = ln(row_sum) + row_max * ln(2).
This commit is contained in:
2026-05-23 21:15:14 +00:00
parent a5061a24b9
commit aff208fb4c

View File

@@ -445,14 +445,17 @@ class FmhaKernel:
c_pipe.producer_tail()
# D5a: Write LSE (log-sum-exp) when normalize=False
# lse = log(row_sum) + row_max_safe (row_max in scaled domain)
# Only thread 0 of the epilogue warps writes LSE for this tile.
# lse = ln(row_sum) + attn_max
# row_max is in the scale_log2 domain: max(S * scale * log2(e))
# attn_max = row_max * ln(2) (converting log2 domain to natural log domain)
# So lse = ln(row_sum) + row_max * ln(2)
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
if sfw_idx == 0:
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
# Write LSE to GMEM: mLSE is a (1,1,1) FP32 tensor
mLSE[0, 0, 0] = lse_val