Add: O rescale (correction_rescale) in softmax loop + remove pk from TMA/MMA
This commit is contained in:
@@ -224,7 +224,6 @@ class FmhaV3StageCMulti:
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
@@ -233,11 +232,11 @@ class FmhaV3StageCMulti:
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
qc.reset(); qh = qc.wait_and_advance(); qh.release()
|
||||
kvc.reset(); pk = kvc.try_wait()
|
||||
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
kvc.reset()
|
||||
for kt in range(n_tiles):
|
||||
kvh = kvc.wait_and_advance() acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_st)
|
||||
for kt in range(n_kv_tiles):
|
||||
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
|
||||
sh = s_prod.acquire_and_advance()
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
|
||||
@@ -290,6 +289,29 @@ class FmhaV3StageCMulti:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# O rescale setup: same correction_rescale pattern as final normalize.
|
||||
# Uses paired Ld32x32bOp/St32x32bOp atoms with matching Repetition(16).
|
||||
corr_tile_size = 16
|
||||
cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO_corr = pv_thr.partition_C(cO_corr)
|
||||
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO_corr.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
tOcO_i = cute.make_tensor(tOcO_corr.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tmem_store_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOAD_OtO = thr_tmem_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOAD_OcO = thr_tmem_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STORE_OtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
@@ -334,6 +356,28 @@ class FmhaV3StageCMulti:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# O rescale: multiply existing O by acc_scale = exp2(old_max - new_max)
|
||||
# Uses the correction_rescale pattern (same paired atoms as final normalize).
|
||||
# Must happen BEFORE softmax_done_bar.arrive() so MMA's PV[kt] sees rescaled O.
|
||||
if kt > 0:
|
||||
for ci in range(HEAD_DIM // corr_tile_size):
|
||||
tTMrO_i_ = tTMrO[None, ci]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOAD_OtO_i = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
|
||||
)
|
||||
tTMEM_STORE_OtO_i = cute.make_tensor(
|
||||
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
|
||||
# store BF16 P through the FP32-backed register bridge.
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
|
||||
Reference in New Issue
Block a user