Guard sO_acc allocation/zero-init with n_kv_tiles>1
This commit is contained in:
@@ -186,13 +186,20 @@ class FmhaKernel:
|
||||
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
|
||||
|
||||
# SMEM accumulator: FP32 [128, pv_n_tile] row-major
|
||||
# Allocate the full-size buffer only for multi-KV-tile (n_kv_tiles > 1); otherwise a 1x1 placeholder is allocated.
|
||||
# For n_kv_tiles=1, epilogue_tma_store reads from TMEM directly.
|
||||
sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1))
|
||||
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128)
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128)
|
||||
else:
|
||||
# Placeholder — never accessed
|
||||
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
|
||||
|
||||
# Zero-initialize sO_acc
|
||||
# Zero-initialize sO_acc (only for multi-KV-tile)
|
||||
if warp_idx < self.mma_warp_id:
|
||||
for i in cutlass.range(0, cute.size(sO_acc), unroll=1):
|
||||
sO_acc[i] = Float32(0.0)
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
for i in cutlass.range(0, cute.size(sO_acc), unroll=1):
|
||||
sO_acc[i] = Float32(0.0)
|
||||
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
@@ -452,21 +459,13 @@ class FmhaKernel:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# --- Wait for PV[kt] to complete ---
|
||||
pv_done_bar.arrive_and_wait()
|
||||
|
||||
# ========================================================
|
||||
# SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM
|
||||
# ========================================================
|
||||
# Only needed for n_kv_tiles > 1.
|
||||
# For n_kv_tiles=1, epilogue_tma_store reads directly from TMEM.
|
||||
#
|
||||
# WARNING: Ld32x32bOp may be a destructive TMEM read!
|
||||
# If so, the O load here would consume TMEM data,
|
||||
# making epilogue_tma_store read garbage.
|
||||
# Guard with const_expr to avoid this for n_kv_tiles=1.
|
||||
# ========================================================
|
||||
# --- Wait for PV[kt] to complete (only if accumulating in SMEM) ---
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
pv_done_bar.arrive_and_wait()
|
||||
|
||||
# ========================================================
|
||||
# SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM
|
||||
# ========================================================
|
||||
rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_o_load, tTMEM_LOADtO, rO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
Reference in New Issue
Block a user