From 84cd636ba94a8572079ac73cbefa67f88ff5f768 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 17:58:04 +0000 Subject: [PATCH] Stage C fixes: pv_done_bar sync, acc_scale with scale, fastmath=True - Add pv_done_bar (barrier_id=4): MMA signals PV complete, epilogue waits before O rescale (C6) and final normalization (C9) - Fix acc_scale: exp2(scale * (old_max - new_max)) includes the scale_softmax_log2 factor matching CUTLASS FMHA reference - fastmath=True for both exp2 calls (P computation + rescale) - No *0.5 (our scalar row_sum pattern initializes (0,0) not (sum,sum)) --- tests/unit/test_fmha_v3_softmax.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_fmha_v3_softmax.py b/tests/unit/test_fmha_v3_softmax.py index ddd84e8e..da21ced1 100644 --- a/tests/unit/test_fmha_v3_softmax.py +++ b/tests/unit/test_fmha_v3_softmax.py @@ -2,6 +2,12 @@ FMHA v3 + Stage C: QK -> online softmax -> PV with KV-tile interleaving. Stage C: row_max, exp2, O rescale, row_sum, final normalization. FMHA pattern P store preserved from Stage B. + +Fixes applied: +- pv_done_bar (barrier_id=4): MMA signals PV complete, epilogue waits before O rescale (C6, C9) +- acc_scale includes scale_softmax_log2: exp2(scale * (old_max - new_max)) +- fastmath=True for exp2 calls +- No *0.5 (scalar row_sum pattern does not need it) """ import math import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline @@ -112,6 +118,7 @@ class FmhaV3Softmax: kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) + pv_done_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) @@ -199,6 +206,8 @@ class FmhaV3Softmax: pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() vh.release() + # Signal PV done - O is now safe for epilogue rescale + pv_done_bar.arrive() acc_pipe.producer_commit(acc_st); acc_st.advance() acc_pipe.producer_tail(acc_st) @@ -274,10 +283,12 @@ class FmhaV3Softmax: row_max_safe = cutlass.Float32(0.0) # --- C5: Compute rescale factor --- - acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=False) + acc_scale = cute.math.exp2(scale * (old_row_max - row_max_safe), fastmath=True) # --- C6: Rescale O in TMEM (load O, multiply by acc_scale, store O) --- if kt > 0: + # Wait for previous PV to finish writing O before rescaling + pv_done_bar.arrive_and_wait() tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype) for i in range(o_col_tiles): tTMrO_i_ = tTMrO[None, i] @@ -310,7 +321,7 @@ class FmhaV3Softmax: for j in range(frg_cnt): for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale + minus_row_max_scale - tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=False) + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) s_vec = tTMEM_LOADrS_frg[None, j].load() rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) @@ -350,8 +361,8 @@ class FmhaV3Softmax: row_sum = row_sum + tile_sum # --- C9: Final normalization via O TMEM rescale --- - # After all KV tiles, O = sum(P_i @ V_i) but unnormalized. - # Load O, multiply by 1/row_sum, store O. Then use identity epilogue. + # Wait for the last PV to finish before touching O + pv_done_bar.arrive_and_wait() inv_row_sum = cutlass.Float32(1.0) / row_sum tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)