D1.3: Full correction_epilog with TMA store, normalize in reg before SMEM write
One-way trip: TMEM->reg (normalize) ->SMEM->GMEM Replicates epilogue_tma_store logic with normalize step added Uses CUTLASS helpers for correct layout handling
This commit is contained in:
@@ -370,38 +370,52 @@ class FmhaKernel:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === Correction epilog: one-way TMEM -> reg -> SMEM -> GMEM ===
|
||||
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) for correct TMEM read.
|
||||
# Uses epilogue_smem_copy_and_partition (get_smem_store_op) for correct SMEM write.
|
||||
# No TMEM round-trip. No layout mismatch. No 3% error.
|
||||
# === Correction epilog: TMEM -> reg (normalize) -> SMEM -> GMEM ===
|
||||
# Full pipeline using CUTLASS epilogue helpers for correct layout handling.
|
||||
# Replaces the broken NO-OP TMEM round-trip + normalize approach.
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
# Transform layout: ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE)
|
||||
# -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE)
|
||||
tCtO = transform_partitioned_tensor_layout(tCtO_base)
|
||||
# Transform gC layout similarly (needed by the helpers)
|
||||
tCgC_xform = transform_partitioned_tensor_layout(tCgC)
|
||||
|
||||
# TMEM->reg copy (uses get_tmem_load_op for correct layout)
|
||||
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs
|
||||
)
|
||||
# reg->SMEM copy (uses get_smem_store_op for correct layout)
|
||||
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
|
||||
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC
|
||||
)
|
||||
# TMA SMEM->GMEM partition
|
||||
tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
|
||||
# Wait for accumulator buffer
|
||||
acc_pipe.consumer_wait(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage))
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
acc_pipe.consumer_wait(acc_cons_st)
|
||||
|
||||
# Process each subtile: TMEM load -> normalize -> BF16 convert -> SMEM store
|
||||
# Process subtiles
|
||||
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
|
||||
subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3])
|
||||
epilog_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)))
|
||||
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
# Load from TMEM
|
||||
tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
||||
|
||||
# Normalize: O *= 1/row_sum
|
||||
# NORMALIZE: O *= 1/row_sum (the key addition vs. epilogue_tma_store)
|
||||
for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True):
|
||||
tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum
|
||||
|
||||
@@ -413,28 +427,20 @@ class FmhaKernel:
|
||||
c_buffer = subtile_idx % self.num_c_stage
|
||||
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
# TMA store from SMEM to GMEM
|
||||
# Partition sC and gC for TMA store (using transformed gC)
|
||||
tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
# Only warp 0 of epilogue issues TMA store
|
||||
# TMA store SMEM -> GMEM
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)])
|
||||
# Sync after TMA store
|
||||
epilog_sync_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_g[(None, subtile_idx)])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_acquire()
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
# Release accumulator buffer
|
||||
with cute.arch.elect_one():
|
||||
acc_pipe.consumer_release(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage))
|
||||
acc_pipe.consumer_release(acc_cons_st)
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user