D1.5: Simplify TMA store - use 2D sC_epi and gC_epi views
This commit is contained in:
@@ -490,25 +490,19 @@ class FmhaKernel:
|
||||
|
||||
# Step 6: TMA store SMEM → GMEM
|
||||
# The normalized O is now in sC (written by the correction epilogue).
|
||||
# TMA store from sC to the output tensor in GMEM.
|
||||
# Use the pre-partitioned tCgC (GMEM partition) and sC (SMEM buffer).
|
||||
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None))
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
# Use the same TMA store pattern as the CUTLASS FMHA reference.
|
||||
# Partition sC and gC for the bulk TMA copy.
|
||||
sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA
|
||||
gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) # 2D output tile
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
sC_epi, gC_epi,
|
||||
)
|
||||
# One TMA store for the full output tile
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, 0)])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_tail()
|
||||
cute.copy(tma_c, bSG_sC, bSG_gC)
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
# D5a: Write LSE (log-sum-exp) when normalize=False
|
||||
# lse = ln(row_sum) + row_max * ln(2)
|
||||
|
||||
Reference in New Issue
Block a user