add explicit acc_pipe.consumer_wait before final normalize
Race condition: softmax reads O to normalize while MMA may still be writing PV[N-1]. Single-tile wins by luck; multi-tile drifts. Move acc_cons_st construction before the wait so epilogue reuses it.
This commit is contained in:
@@ -382,14 +382,16 @@ class FmhaV3StageCMulti:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# --- Wait for MMA to finish PV[N-1] before reading O ---
|
||||
# Without this, the final normalize reads O while MMA may still be
|
||||
# writing PV[N-1]. Single-tile wins the race by luck; multi-tile drifts.
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
acc_pipe.consumer_wait(acc_cons_st)
|
||||
|
||||
# --- Final O normalization by row_sum ---
|
||||
# After the last softmax iteration, MMA still needs to finish the final
|
||||
# PV[N-1]. The acc_pipe consumer wait inside epilogue_tma_store handles
|
||||
# that for the GMEM write, but we also need O fully accumulated before
|
||||
# we divide by row_sum here. In practice MMA's PV[N-1] is small enough
|
||||
# that by the time softmax has fallen out of the loop it has retired.
|
||||
# If you observe drift, insert an explicit acc_pipe.consumer_wait
|
||||
# before this block — see the consumer state setup below.
|
||||
# PV[N-1]. The acc_pipe consumer wait above ensures O is fully accumulated
|
||||
# before we divide by row_sum here.
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
for i in range(o_col_tiles):
|
||||
tTMEM_LOAD_O_i = cute.make_tensor(
|
||||
@@ -410,7 +412,6 @@ class FmhaV3StageCMulti:
|
||||
|
||||
# --- Epilogue: TMEM -> SMEM -> GMEM via TMA store ---
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
@@ -450,7 +451,8 @@ def test():
|
||||
kernel = FmhaV3StageCMulti(s_k=n)
|
||||
print(f'n={n}: Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
print(f'n={n}: tmem_offsets: s0={kernel.tmem_s0_offset} '
|
||||
n_tiles = n // 128
|
||||
print(f'n={n}: n_kv_tiles should be {n_tiles}, tmem_offsets: s0={kernel.tmem_s0_offset} '
|
||||
f'p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} '
|
||||
f'alloc={kernel.num_tmem_alloc_cols}', flush=True)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
|
||||
Reference in New Issue
Block a user