Fix TMA store: use bSG_sC[(None,0)] indexing pattern from epilogue_tma_store
This commit is contained in:
@@ -547,22 +547,31 @@ class FmhaKernel:
|
||||
# Step 2: TMA store sC -> GMEM
|
||||
# Use cpasync.tma_partition (same as epilogue_tma_store)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
# Sync barrier for SMEM->GMEM ordering
|
||||
epilog_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
# Transform tCgC layout (same as epilogue_tma_store)
|
||||
tCgC = transform_partitioned_tensor_layout(tCgC)
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
# Create TMA partition from sC and gC
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
# Transform tCgC layout and partition for TMA
|
||||
tCgC_xfm = transform_partitioned_tensor_layout(tCgC)
|
||||
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
cute.copy(tma_c, bSG_sC[None, ...], bSG_gC[None, ...])
|
||||
c_pipe.producer_commit()
|
||||
# TMA store: only the first epilogue warp does the copy
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, Int32(0))])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_acquire()
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
|
||||
Reference in New Issue
Block a user