Clean up SMEM acc epilogue: flat indexing sO_acc->sC, TMA store from sC_s0
This commit is contained in:
@@ -523,25 +523,25 @@ class FmhaKernel:
|
||||
sO_acc[row, col] = sO_acc[row, col] * inv_row_sum
|
||||
|
||||
# ============================================================
|
||||
# Cast sO_acc (FP32) -> sC (BF16) and TMA store to GMEM
|
||||
|
||||
# ============================================================
|
||||
# sC has a swizzled layout. We need to write using sC's native
|
||||
# coordinate system. The epi_tile defines the logical tile shape.
|
||||
#
|
||||
# Strategy: use epi_s (the TMA-compatible view of sC) to write
|
||||
# sO_acc data into sC, then TMA copy sC -> gC.
|
||||
# EPILOGUE: write sO_acc to GMEM
|
||||
# ============================================================
|
||||
# Strategy: write sO_acc (FP32) -> sC (BF16) -> TMA store to GMEM.
|
||||
# The sC write uses flat indexing over the stage-0 slice.
|
||||
# The TMA store uses the epi_s layout that tma_c was created from.
|
||||
# ============================================================
|
||||
|
||||
# epi_s is the 2-mode view of sC that tma_c was created from
|
||||
epi_s = cute.select(c_smem_s, mode=[0, 1])
|
||||
sC_view = cute.make_tensor(sC.iterator, epi_s) # TMA-compatible layout
|
||||
|
||||
# Write sO_acc -> sC using sC_view's coordinate system
|
||||
# sC_view is indexed by epi_tile coordinates
|
||||
# For simple row-major epi_tile: (row, col) works
|
||||
for row in cutlass.range(0, 128, unroll=1):
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
sC_view[row, col] = sO_acc[row, col].to(self.o_dtype)
|
||||
# Cast sO_acc (FP32) -> sC (BF16), stage 0
|
||||
# sC layout from make_smem_layout_epi: complex swizzled layout.
|
||||
# We write via flat indexing on the stage-0 slice.
|
||||
# The layout will handle the swizzle automatically.
|
||||
sC_s0 = sC[(None, None, Int32(0))]
|
||||
for i in cutlass.range(0, cute.size(sC_s0), unroll=1):
|
||||
row = i // self.pv_n_tile
|
||||
col = i % self.pv_n_tile
|
||||
if row < Int32(128) and col < self.pv_n_tile:
|
||||
sC_s0[i] = sO_acc[row, col].to(self.o_dtype)
|
||||
|
||||
# TMA store sC -> GMEM
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
@@ -550,9 +550,9 @@ class FmhaKernel:
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
cute.copy(tma_c, sC_view, tCgC[(None, None, Int32(0))])
|
||||
cute.copy(tma_c, sC_s0, tCgC[(None, None, Int32(0))])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user