Fix TMA store: use epi_s view of sC for proper layout compatibility
This commit is contained in:
@@ -522,44 +522,35 @@ class FmhaKernel:
|
||||
if row < Int32(128):
|
||||
sO_acc[row, col] = sO_acc[row, col] * inv_row_sum
|
||||
|
||||
# Copy sO_acc (FP32) -> sC (BF16) using SMEM copy
|
||||
# sC has swizzled layout from compute_epilogue_tile_shape,
|
||||
# but we can write to it using the epi_tile coordinate mapping.
|
||||
# ============================================================
|
||||
# Cast sO_acc (FP32) -> sC (BF16) and TMA store to GMEM
|
||||
# ============================================================
|
||||
# sC has a swizzled layout. We need to write using sC's native
|
||||
# coordinate system. The epi_tile defines the logical tile shape.
|
||||
#
|
||||
# Alternative: use TMA store directly from a properly laid out SMEM buffer.
|
||||
# The simplest correct approach: use epilogue_tma_store but read from
|
||||
# a SMEM buffer instead of TMEM.
|
||||
#
|
||||
# For the MVP, we use the existing sC layout and write via
|
||||
# the epi_tile partition that TMA expects.
|
||||
# Strategy: use epi_s (the TMA-compatible view of sC) to write
|
||||
# sO_acc data into sC, then TMA copy sC -> gC.
|
||||
# ============================================================
|
||||
|
||||
# Use epilogue_tma_store to write sO_acc -> GMEM
|
||||
# But epilogue_tma_store reads from TMEM, not SMEM.
|
||||
# We need a different TMA store path.
|
||||
#
|
||||
# Simplest: use cpasync.bulk_copy (SMEM->GMEM) with sC as source.
|
||||
# First: copy sO_acc -> sC (FP32->BF16 cast)
|
||||
# Then: TMA bulk copy sC -> GMEM
|
||||
#
|
||||
# Write to sC row by row using the epi_tile coordinate mapping.
|
||||
# The epi_tile shape is derived from cta_tile_shape_mnk.
|
||||
# For hd=64 with pv_n_tile=64: epi_tile covers (128, 64).
|
||||
# epi_s is the 2-mode view of sC that tma_c was created from
|
||||
epi_s = cute.select(c_smem_s, mode=[0, 1])
|
||||
sC_view = cute.make_tensor(sC.iterator, epi_s) # TMA-compatible layout
|
||||
|
||||
# For each row assigned to this thread, cast FP32->BF16
|
||||
# and write to sC using flat index mapping.
|
||||
# sC is 2-stage: sC[128, pv_n_tile, num_c_stage] in BF16
|
||||
c_stage0 = cute.slice_(sC, (None, None, 0)) # First stage of sC
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
row = sfw_idx
|
||||
if row < Int32(128):
|
||||
c_stage0[row, col] = sO_acc[row, col].to(self.o_dtype)
|
||||
# Write sO_acc -> sC using sC_view's coordinate system
|
||||
# sC_view is indexed by epi_tile coordinates
|
||||
# For simple row-major epi_tile: (row, col) works
|
||||
for row in cutlass.range(0, 128, unroll=1):
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
sC_view[row, col] = sO_acc[row, col].to(self.o_dtype)
|
||||
|
||||
# TMA store sC -> GMEM
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
cute.copy(tma_c, c_stage0, tCgC[(None, None, Int32(0))])
|
||||
cute.copy(tma_c, sC_view, tCgC[(None, None, Int32(0))])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user