Use sC_flat (non-swizzled epi_s layout) for TMA store from SMEM accumulator

This commit is contained in:
2026-05-27 05:26:50 +00:00
parent 4a2a06f9e1
commit 6fb0e6a417

View File

@@ -224,8 +224,12 @@ class FmhaKernel:
if const_expr(self.use_smem_accumulator):
sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1))
sO_acc = smem.allocate_tensor(element_type=Float32, layout=sO_acc_layout, byte_alignment=128)
# sC_flat: BF16 SMEM buffer with epi_s layout (non-swizzled) for TMA store
# Used to cast sO_acc (FP32) -> BF16 and TMA store to GMEM
sC_flat = smem.allocate_tensor(element_type=self.o_dtype, layout=cute.select(self.c_smem_s, mode=[0, 1]).outer, byte_alignment=128)
else:
sO_acc = smem.allocate_tensor(element_type=Float32, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
sC_flat = smem.allocate_tensor(element_type=self.o_dtype, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
@@ -606,40 +610,31 @@ class FmhaKernel:
)
c_pipe.producer_tail()
else:
# Path 2: write sO_acc (FP32) -> sC -> GMEM via TMA
# Follow the CUTLASS FMHA reference pattern:
# 1. Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing
# 2. Use flat_divide on mC to create gO, then tma_partition, then copy
# Path 2: write sO_acc (FP32) -> sC_flat (BF16) -> TMA store to GMEM
# sC_flat has epi_s layout (same as what tma_c was created from)
# Step 1: Cast sO_acc -> sC (BF16)
# Step 1: Cast sO_acc -> sC_flat (BF16)
for row in cutlass.range(0, 128, unroll=1):
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
val = sO_acc[row, col]
if const_expr(self.normalize):
inv_row_sum = Float32(1.0) / row_sum
val = val * inv_row_sum
sC[(row, col), Int32(0), Int32(0)] = val.to(self.o_dtype)
sC_flat[row, col] = val.to(self.o_dtype)
cute.arch.fence_proxy("async.shared", space="cta")
# Step 2: TMA store sC -> GMEM (CUTLASS FMHA reference pattern)
# Create gO from mC via flat_divide + slice (same as CUTLASS reference)
# Step 2: TMA store sC_flat -> GMEM
gO_qdl = cute.flat_divide(
mC, cute.select(self.pv_mma_tiler, mode=[0, 1])
)
gO = gO_qdl[None, None, None, Int32(0), Int32(0)]
tOsO, tOgO = cpasync.tma_partition(
tOsC, tOgO = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(sC_flat, 0, 2),
cute.group_modes(gO, 0, 2),
)
# Wait for all epilogue warps to finish writing to sC
epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=self.epilog_sync_bar_id,
num_threads=32 * len(self.epilogue_warp_id),
)
epilog_sync_barrier.arrive_and_wait()
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, tOsO[None, Int32(0)], tOgO[None, Int32(0)])
cute.copy(tma_c, tOsC[None, Int32(0)], tOgO[None, Int32(0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)