fix: LSE type mismatch Float32→BFloat16
This commit is contained in:
@@ -443,7 +443,7 @@ class FmhaKernel:
|
||||
if sfw_idx == 0:
|
||||
_ln2 = Float32(0.6931471805599453) # ln(2)
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
|
||||
mLSE[0] = lse_val
|
||||
mLSE[0] = lse_val.to(self.q_dtype)
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user