D1.5: Fix TMA store rank mismatch - use 2D sC_epi view
This commit is contained in:
@@ -487,10 +487,15 @@ class FmhaKernel:
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
# TMA store: SMEM → GMEM
|
||||
# Use the TMA partition from the kernel setup
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
c_pipe.producer_acquire()
|
||||
cute.copy(tma_c, sC, tCgC)
|
||||
# Use epi_tile-matched GMEM partition
|
||||
gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0)))
|
||||
# TMA store from sC (2D view) to gC (2D view)
|
||||
sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA
|
||||
cute.copy(tma_c, sC_epi, gC_epi)
|
||||
c_pipe.producer_commit()
|
||||
cute.arch.gpu_bar_sync()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
Reference in New Issue
Block a user