D5a: Fix LSE - compute row_max_safe from final row_max, remove mLSE None check

This commit is contained in:
2026-05-23 21:12:29 +00:00
parent a331c8816a
commit ba288cfbd3

View File

@@ -445,14 +445,17 @@ class FmhaKernel:
c_pipe.producer_tail()
# D5a: Write LSE (log-softmax) when normalize=False
# lse = log(row_sum) + row_max (row_max in scaled domain)
# lse = log(row_sum) + row_max_safe (row_max in scaled domain)
# row_max_safe = row_max if row_max != -inf else 0
# Only thread 0 of the epilogue warps writes LSE for this tile.
# For M=1 decode: one lse value per query row.
if const_expr(not self.normalize):
if mLSE is not None:
if sfw_idx == 0:
lse_val = cute.math.log(row_sum, fastmath=True) + row_max_safe
mLSE[None, None, 0] = lse_val
# Compute row_max_safe from the final row_max
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
if sfw_idx == 0:
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe
mLSE[None, None, 0] = lse_val
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)