Simplify: softmax P only, no O rescale/normalize yet
This commit is contained in:
@@ -226,23 +226,7 @@ class FmhaV3RealSoftmax:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# O normalize setup
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
corr_tile_size = 16
|
||||
tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i)
|
||||
# TODO: O normalize setup (add when O rescale is ready)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -276,16 +260,7 @@ class FmhaV3RealSoftmax:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# O rescale in TMEM (only for kt > 0)
|
||||
if kt > 0:
|
||||
# Read O, multiply by acc_scale, write back
|
||||
tTMEM_LOAD_OrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMEM_LOAD_OrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for j in range(cute.size(tTMEM_LOAD_OrO)):
|
||||
tTMEM_LOAD_OrO[None, j] = tTMEM_LOAD_OrO[None, j] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMEM_LOAD_OrO, tTMEM_STORE_OtO)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
# TODO: O rescale in TMEM (skip for now, test softmax + P only)
|
||||
|
||||
# Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
@@ -306,15 +281,10 @@ class FmhaV3RealSoftmax:
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Final O normalization: O = O / row_sum
|
||||
if row_sum != Float32(0.0):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
tTMEM_LOAD_OrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMEM_LOAD_OrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for j in range(cute.size(tTMEM_LOAD_OrO)):
|
||||
tTMEM_LOAD_OrO[None, j] = tTMEM_LOAD_OrO[None, j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMEM_LOAD_OrO, tTMEM_STORE_OtO)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
# TODO: enable after basic softmax works
|
||||
# if row_sum != Float32(0.0):
|
||||
# inv_row_sum = Float32(1.0) / row_sum
|
||||
# ...
|
||||
|
||||
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
|
||||
Reference in New Issue
Block a user