diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 39405779..154ccd1f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -443,7 +443,7 @@ class FmhaKernel: if sfw_idx == 0: _ln2 = Float32(0.6931471805599453) # ln(2) lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[0] = lse_val + mLSE[0] = lse_val.to(self.q_dtype) tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)