From b43ffe9dace22fb1ee982d3ed5aea3f57348dd23 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 05:05:01 +0000 Subject: [PATCH] Guard sO_acc allocation/zero-init with n_kv_tiles>1 --- dsv4/kernels/attention/fmha_smem_acc.py | 35 ++++++++++++------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index a127f078..b2f47169 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -186,13 +186,20 @@ class FmhaKernel: sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) # SMEM accumulator: FP32 [128, pv_n_tile] row-major + # Only allocate for multi-KV-tile (n_kv_tiles > 1) + # For n_kv_tiles=1, epilogue_tma_store reads from TMEM directly. sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1)) - sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128) + if const_expr(self.n_kv_tiles > 1): + sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128) + else: + # Placeholder — never accessed + sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128) - # Zero-initialize sO_acc + # Zero-initialize sO_acc (only for multi-KV-tile) if warp_idx < self.mma_warp_id: - for i in cutlass.range(0, cute.size(sO_acc), unroll=1): - sO_acc[i] = Float32(0.0) + if const_expr(self.n_kv_tiles > 1): + for i in cutlass.range(0, cute.size(sO_acc), unroll=1): + sO_acc[i] = Float32(0.0) gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) @@ -452,21 +459,13 @@ class FmhaKernel: si_handle.release() softmax_done_bar.arrive() - # --- Wait for PV[kt] to complete --- - pv_done_bar.arrive_and_wait() - - # ======================================================== - # SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM - # ======================================================== - # Only needed for n_kv_tiles > 1. - # For n_kv_tiles=1, epilogue_tma_store reads directly from TMEM. - # - # WARNING: Ld32x32bOp may be a destructive TMEM read! - # If so, the O load here would consume TMEM data, - # making epilogue_tma_store read garbage. - # Guard with const_expr to avoid this for n_kv_tiles=1. - # ======================================================== + # --- Wait for PV[kt] to complete (only if accumulating in SMEM) --- if const_expr(self.n_kv_tiles > 1): + pv_done_bar.arrive_and_wait() + + # ======================================================== + # SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM + # ======================================================== rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype) cute.copy(tiled_o_load, tTMEM_LOADtO, rO) cute.arch.fence_view_async_tmem_load()