D1.5: Fix TMA store - use local_tile with pv_mma_tiler
This commit is contained in:
@@ -487,15 +487,14 @@ class FmhaKernel:
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
# TMA store: SMEM → GMEM
|
||||
# Use the TMA partition from the kernel setup
|
||||
# Reuse the existing TMA partition (tCgC) which was set up at kernel start.
|
||||
# sC was written by the correction epilog. TMA reads from sC → GMEM via tCgC.
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
c_pipe.producer_acquire()
|
||||
# Use epi_tile-matched GMEM partition
|
||||
gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0)))
|
||||
# TMA store from sC (2D view) to gC (2D view)
|
||||
sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA
|
||||
cute.copy(tma_c, sC_epi, gC_epi)
|
||||
# TMA store from sC to GMEM using the pre-partitioned gC
|
||||
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None))
|
||||
cute.copy(tma_c, cute.select(sC, mode=[0, 1]), cute.select(gC, mode=[0, 1]))
|
||||
c_pipe.producer_commit()
|
||||
cute.arch.gpu_bar_sync()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
Reference in New Issue
Block a user