Use cpasync.tma_partition for SMEM->GMEM TMA store (like epilogue_tma_store)

This commit is contained in:
2026-05-27 04:58:47 +00:00
parent e614d0894c
commit 7ea77a121f

View File

@@ -525,32 +525,43 @@ class FmhaKernel:
# ============================================================
# ============================================================
# EPILOGUE: write sO_acc to GMEM
# EPILOGUE: write sO_acc to sC, TMA store sC -> GMEM
# ============================================================
# Strategy: write sO_acc (FP32) -> sC (BF16) -> TMA store to GMEM.
# The sC write uses flat indexing over the stage-0 slice.
# The TMA store uses the epi_s layout that tma_c was created from.
# Write sO_acc (FP32) -> sC (BF16) using sC's layout indexing.
# Then use cpasync.tma_partition + cute.copy for TMA store.
# ============================================================
# Cast sO_acc (FP32) -> sC (BF16), stage 0
# sC layout from make_smem_layout_epi: complex swizzled layout.
# We write via flat indexing on the stage-0 slice.
# The layout will handle the swizzle automatically.
sC_s0 = sC[(None, None, Int32(0))]
for i in cutlass.range(0, cute.size(sC_s0), unroll=1):
row = i // self.pv_n_tile
col = i % self.pv_n_tile
if row < Int32(128) and col < self.pv_n_tile:
sC_s0[i] = sO_acc[row, col].to(self.o_dtype)
# Step 1: Write sO_acc -> sC (BF16)
# Use flat indexing on sC with stage dimension removed.
# sC has layout from make_smem_layout_epi: ((M, N), ?, num_c_stage, ...).
# We write to stage 0 using the epi_s (2-mode) view.
# Since we can't easily create a tensor with epi_s layout from sC's pointer
# (swizzle conflict), we write to sC via its native 4D layout.
#
# sC shape: ((128, pv_n_tile), 1, num_c_stage, ...)
# Index: sC[(row, col), 0, stage_idx, ...]
for row in cutlass.range(0, 128, unroll=1):
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype)
# TMA store sC -> GMEM
# Step 2: TMA store sC -> GMEM
# Use cpasync.tma_partition (same as epilogue_tma_store)
cute.arch.fence_proxy("async.shared", space="cta")
c_pipe = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
)
c_pipe.producer_acquire()
cute.copy(tma_c, sC_s0, tCgC[(None, None, Int32(0))])
# Transform tCgC layout (same as epilogue_tma_store)
tCgC = transform_partitioned_tensor_layout(tCgC)
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
# Create TMA partition from sC and gC
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC_epi, 0, 2),
)
cute.copy(tma_c, bSG_sC[None, ...], bSG_gC[None, ...])
c_pipe.producer_commit()
c_pipe.producer_tail()