D5a: Use tensor indexing for LSE write

This commit is contained in:
2026-05-23 21:13:52 +00:00
parent 7a87c634fb
commit a5061a24b9

View File

@@ -453,7 +453,8 @@ class FmhaKernel:
_row_max_safe = Float32(0.0)
if sfw_idx == 0:
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe
-cute.store(lse_val, mLSE[None, None, 0])
+# Write LSE to GMEM: mLSE is a (1,1,1) FP32 tensor
+mLSE[0, 0, 0] = lse_val
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)