Guard sO_acc allocation/zero-init with n_kv_tiles>1

This commit is contained in:
2026-05-27 05:05:01 +00:00
parent 101840c78c
commit b43ffe9dac

View File

@@ -186,13 +186,20 @@ class FmhaKernel:
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
# SMEM accumulator: FP32 [128, pv_n_tile] row-major
# Only allocate for multi-KV-tile (n_kv_tiles > 1)
# For n_kv_tiles=1, epilogue_tma_store reads from TMEM directly.
sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1))
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128)
if const_expr(self.n_kv_tiles > 1):
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128)
else:
# Placeholder — never accessed
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
# Zero-initialize sO_acc
# Zero-initialize sO_acc (only for multi-KV-tile)
if warp_idx < self.mma_warp_id:
for i in cutlass.range(0, cute.size(sO_acc), unroll=1):
sO_acc[i] = Float32(0.0)
if const_expr(self.n_kv_tiles > 1):
for i in cutlass.range(0, cute.size(sO_acc), unroll=1):
sO_acc[i] = Float32(0.0)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
@@ -452,21 +459,13 @@ class FmhaKernel:
si_handle.release()
softmax_done_bar.arrive()
# --- Wait for PV[kt] to complete ---
pv_done_bar.arrive_and_wait()
# ========================================================
# SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM
# ========================================================
# Only needed for n_kv_tiles > 1.
# For n_kv_tiles=1, epilogue_tma_store reads directly from TMEM.
#
# WARNING: Ld32x32bOp may be a destructive TMEM read!
# If so, the O load here would consume TMEM data,
# making epilogue_tma_store read garbage.
# Guard with const_expr to avoid this for n_kv_tiles=1.
# ========================================================
# --- Wait for PV[kt] to complete (only if accumulating in SMEM) ---
if const_expr(self.n_kv_tiles > 1):
pv_done_bar.arrive_and_wait()
# ========================================================
# SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM
# ========================================================
rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype)
cute.copy(tiled_o_load, tTMEM_LOADtO, rO)
cute.arch.fence_view_async_tmem_load()