diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 67f65c77..2eeb3eea 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -445,14 +445,17 @@ class FmhaKernel: c_pipe.producer_tail() # D5a: Write LSE (log-softmax) when normalize=False - # lse = log(row_sum) + row_max (row_max in scaled domain) + # lse = log(row_sum) + row_max_safe (row_max in scaled domain) + # row_max_safe = row_max if row_max != -inf else 0 # Only thread 0 of the epilogue warps writes LSE for this tile. - # For M=1 decode: one lse value per query row. if const_expr(not self.normalize): - if mLSE is not None: - if sfw_idx == 0: - lse_val = cute.math.log(row_sum, fastmath=True) + row_max_safe - mLSE[None, None, 0] = lse_val + # Compute row_max_safe from the final row_max + _row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + _row_max_safe = Float32(0.0) + if sfw_idx == 0: + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe + mLSE[None, None, 0] = lse_val tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)