D1.5: Fix TMA store - use local_tile with pv_mma_tiler

This commit is contained in:
2026-05-24 00:32:35 +00:00
parent a6bf31a22e
commit a59d57e4d5

View File

@@ -487,15 +487,14 @@ class FmhaKernel:
epilog_sync_bar.arrive_and_wait()
# TMA store: SMEM → GMEM
# Use the TMA partition from the kernel setup
# Reuse the existing TMA partition (tCgC) which was set up at kernel start.
# sC was written by the correction epilog. TMA reads from sC → GMEM via tCgC.
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
c_pipe.producer_acquire()
# Use epi_tile-matched GMEM partition
gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0)))
# TMA store from sC (2D view) to gC (2D view)
sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA
cute.copy(tma_c, sC_epi, gC_epi)
# TMA store from sC to GMEM using the pre-partitioned gC
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None))
cute.copy(tma_c, cute.select(sC, mode=[0, 1]), cute.select(gC, mode=[0, 1]))
c_pipe.producer_commit()
cute.arch.gpu_bar_sync()
c_pipe.producer_tail()