Stage C fixes: pv_done_bar sync, acc_scale with scale, fastmath=True
- Add pv_done_bar (barrier_id=4): the MMA warp signals PV complete; the epilogue waits on it before the O rescale (C6) and the final normalization (C9).
- Fix acc_scale: exp2(scale * (old_max - new_max)) now includes the scale_softmax_log2 factor, matching the CUTLASS FMHA reference.
- Use fastmath=True for both exp2 calls (P computation and rescale).
- No *0.5 factor (our scalar row_sum pattern initializes to (0,0), not (sum,sum)).
This commit is contained in:
@@ -2,6 +2,12 @@
|
||||
FMHA v3 + Stage C: QK -> online softmax -> PV with KV-tile interleaving.
|
||||
Stage C: row_max, exp2, O rescale, row_sum, final normalization.
|
||||
FMHA pattern P store preserved from Stage B.
|
||||
|
||||
Fixes applied:
|
||||
- pv_done_bar (barrier_id=4): MMA signals PV complete; epilogue waits before O rescale (C6) and final normalization (C9)
|
||||
- acc_scale includes scale_softmax_log2: exp2(scale * (old_max - new_max))
|
||||
- fastmath=True for exp2 calls
|
||||
- No *0.5 (scalar row_sum pattern does not need it)
|
||||
"""
|
||||
import math
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
@@ -112,6 +118,7 @@ class FmhaV3Softmax:
|
||||
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
|
||||
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
pv_done_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
@@ -199,6 +206,8 @@ class FmhaV3Softmax:
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
vh.release()
|
||||
# Signal PV done - O is now safe for epilogue rescale
|
||||
pv_done_bar.arrive()
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
|
||||
@@ -274,10 +283,12 @@ class FmhaV3Softmax:
|
||||
row_max_safe = cutlass.Float32(0.0)
|
||||
|
||||
# --- C5: Compute rescale factor ---
|
||||
acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=False)
|
||||
acc_scale = cute.math.exp2(scale * (old_row_max - row_max_safe), fastmath=True)
|
||||
|
||||
# --- C6: Rescale O in TMEM (load O, multiply by acc_scale, store O) ---
|
||||
if kt > 0:
|
||||
# Wait for previous PV to finish writing O before rescaling
|
||||
pv_done_bar.arrive_and_wait()
|
||||
tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
|
||||
for i in range(o_col_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
@@ -310,7 +321,7 @@ class FmhaV3Softmax:
|
||||
for j in range(frg_cnt):
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale + minus_row_max_scale
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=False)
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
@@ -350,8 +361,8 @@ class FmhaV3Softmax:
|
||||
row_sum = row_sum + tile_sum
|
||||
|
||||
# --- C9: Final normalization via O TMEM rescale ---
|
||||
# After all KV tiles, O = sum(P_i @ V_i) but unnormalized.
|
||||
# Load O, multiply by 1/row_sum, store O. Then use identity epilogue.
|
||||
# Wait for the last PV to finish before touching O
|
||||
pv_done_bar.arrive_and_wait()
|
||||
inv_row_sum = cutlass.Float32(1.0) / row_sum
|
||||
|
||||
tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
|
||||
|
||||
Reference in New Issue
Block a user