Use cpasync.tma_partition for SMEM->GMEM TMA store (like epilogue_tma_store)
This commit is contained in:
@@ -525,32 +525,43 @@ class FmhaKernel:
|
||||
# ============================================================
|
||||
|
||||
# ============================================================
|
||||
# EPILOGUE: write sO_acc to GMEM
|
||||
# EPILOGUE: write sO_acc to sC, TMA store sC -> GMEM
|
||||
# ============================================================
|
||||
# Strategy: write sO_acc (FP32) -> sC (BF16) -> TMA store to GMEM.
|
||||
# The sC write uses flat indexing over the stage-0 slice.
|
||||
# The TMA store uses the epi_s layout that tma_c was created from.
|
||||
# Write sO_acc (FP32) -> sC (BF16) using sC's layout indexing.
|
||||
# Then use cpasync.tma_partition + cute.copy for TMA store.
|
||||
# ============================================================
|
||||
|
||||
# Cast sO_acc (FP32) -> sC (BF16), stage 0
|
||||
# sC layout from make_smem_layout_epi: complex swizzled layout.
|
||||
# We write via flat indexing on the stage-0 slice.
|
||||
# The layout will handle the swizzle automatically.
|
||||
sC_s0 = sC[(None, None, Int32(0))]
|
||||
for i in cutlass.range(0, cute.size(sC_s0), unroll=1):
|
||||
row = i // self.pv_n_tile
|
||||
col = i % self.pv_n_tile
|
||||
if row < Int32(128) and col < self.pv_n_tile:
|
||||
sC_s0[i] = sO_acc[row, col].to(self.o_dtype)
|
||||
# Step 1: Write sO_acc -> sC (BF16)
|
||||
# Use flat indexing on sC with stage dimension removed.
|
||||
# sC has layout from make_smem_layout_epi: ((M, N), ?, num_c_stage, ...).
|
||||
# We write to stage 0 using the epi_s (2-mode) view.
|
||||
# Since we can't easily create a tensor with epi_s layout from sC's pointer
|
||||
# (swizzle conflict), we write to sC via its native 4D layout.
|
||||
#
|
||||
# sC shape: ((128, pv_n_tile), 1, num_c_stage, ...)
|
||||
# Index: sC[(row, col), 0, stage_idx, ...]
|
||||
for row in cutlass.range(0, 128, unroll=1):
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype)
|
||||
|
||||
# TMA store sC -> GMEM
|
||||
# Step 2: TMA store sC -> GMEM
|
||||
# Use cpasync.tma_partition (same as epilogue_tma_store)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
cute.copy(tma_c, sC_s0, tCgC[(None, None, Int32(0))])
|
||||
# Transform tCgC layout (same as epilogue_tma_store)
|
||||
tCgC = transform_partitioned_tensor_layout(tCgC)
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
# Create TMA partition from sC and gC
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
cute.copy(tma_c, bSG_sC[None, ...], bSG_gC[None, ...])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user