D5a: Use cute.store for LSE write
This commit is contained in:
@@ -446,16 +446,14 @@ class FmhaKernel:
|
||||
|
||||
# D5a: Write LSE (log-sum-exp) when normalize=False
|
||||
# lse = log(row_sum) + row_max_safe (row_max in scaled domain)
|
||||
# row_max_safe = row_max if row_max != -inf else 0
|
||||
# Only thread 0 of the epilogue warps writes LSE for this tile.
|
||||
if const_expr(not self.normalize):
|
||||
# Compute row_max_safe from the final row_max
|
||||
_row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
_row_max_safe = Float32(0.0)
|
||||
if sfw_idx == 0:
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe
|
||||
mLSE[None, None, 0] = lse_val
|
||||
cute.store(lse_val, mLSE[None, None, 0])
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user