fix: LSE type mismatch Float32→BFloat16

This commit is contained in:
2026-05-24 03:20:24 +00:00
parent 791cd8b9c7
commit fe47a5f882

View File

@@ -443,7 +443,7 @@ class FmhaKernel:
if sfw_idx == 0:
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
-mLSE[0] = lse_val
+mLSE[0] = lse_val.to(self.q_dtype)
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)