From 7b8ee862bd4d1bf468a22908df02731fb2c658d4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 15:49:48 +0000 Subject: [PATCH] add explicit acc_pipe.consumer_wait before final normalize Race condition: softmax reads O to normalize while MMA may still be writing PV[N-1]. Single-tile wins by luck; multi-tile drifts. Move acc_cons_st construction before the wait so epilogue reuses it. --- tests/fmha_v3_stage_c_example1.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/fmha_v3_stage_c_example1.py b/tests/fmha_v3_stage_c_example1.py index 502caa9a..9296974e 100644 --- a/tests/fmha_v3_stage_c_example1.py +++ b/tests/fmha_v3_stage_c_example1.py @@ -382,14 +382,16 @@ class FmhaV3StageCMulti: si_handle.release() softmax_done_bar.arrive() + # --- Wait for MMA to finish PV[N-1] before reading O --- + # Without this, the final normalize reads O while MMA may still be + # writing PV[N-1]. Single-tile wins the race by luck; multi-tile drifts. + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + acc_pipe.consumer_wait(acc_cons_st) + # --- Final O normalization by row_sum --- # After the last softmax iteration, MMA still needs to finish the final - # PV[N-1]. The acc_pipe consumer wait inside epilogue_tma_store handles - # that for the GMEM write, but we also need O fully accumulated before - # we divide by row_sum here. In practice MMA's PV[N-1] is small enough - # that by the time softmax has fallen out of the loop it has retired. - # If you observe drift, insert an explicit acc_pipe.consumer_wait - # before this block — see the consumer state setup below. + # PV[N-1]. The acc_pipe consumer wait above ensures O is fully accumulated + # before we divide by row_sum here. inv_row_sum = Float32(1.0) / row_sum for i in range(o_col_tiles): tTMEM_LOAD_O_i = cute.make_tensor( @@ -410,7 +412,6 @@ class FmhaV3StageCMulti: # --- Epilogue: TMEM -> SMEM -> GMEM via TMA store --- tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) @@ -450,7 +451,8 @@ def test(): kernel = FmhaV3StageCMulti(s_k=n) print(f'n={n}: Compiling...', flush=True) compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) - print(f'n={n}: tmem_offsets: s0={kernel.tmem_s0_offset} ' + n_tiles = n // 128 + print(f'n={n}: n_kv_tiles should be {n_tiles}, tmem_offsets: s0={kernel.tmem_s0_offset} ' f'p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} ' f'alloc={kernel.num_tmem_alloc_cols}', flush=True) compiled(mQ, mK, mV, mC, stream)