Use sC_flat (non-swizzled epi_s layout) for TMA store from SMEM accumulator
This commit is contained in:
@@ -224,8 +224,12 @@ class FmhaKernel:
|
||||
if const_expr(self.use_smem_accumulator):
|
||||
sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1))
|
||||
sO_acc = smem.allocate_tensor(element_type=Float32, layout=sO_acc_layout, byte_alignment=128)
|
||||
# sC_flat: BF16 SMEM buffer with epi_s layout (non-swizzled) for TMA store
|
||||
# Used to cast sO_acc (FP32) -> BF16 and TMA store to GMEM
|
||||
sC_flat = smem.allocate_tensor(element_type=self.o_dtype, layout=cute.select(self.c_smem_s, mode=[0, 1]).outer, byte_alignment=128)
|
||||
else:
|
||||
sO_acc = smem.allocate_tensor(element_type=Float32, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
|
||||
sC_flat = smem.allocate_tensor(element_type=self.o_dtype, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
|
||||
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
@@ -606,40 +610,31 @@ class FmhaKernel:
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
else:
|
||||
# Path 2: write sO_acc (FP32) -> sC -> GMEM via TMA
|
||||
# Follow the CUTLASS FMHA reference pattern:
|
||||
# 1. Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing
|
||||
# 2. Use flat_divide on mC to create gO, then tma_partition, then copy
|
||||
# Path 2: write sO_acc (FP32) -> sC_flat (BF16) -> TMA store to GMEM
|
||||
# sC_flat has epi_s layout (same as what tma_c was created from)
|
||||
|
||||
# Step 1: Cast sO_acc -> sC (BF16)
|
||||
# Step 1: Cast sO_acc -> sC_flat (BF16)
|
||||
for row in cutlass.range(0, 128, unroll=1):
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
val = sO_acc[row, col]
|
||||
if const_expr(self.normalize):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
val = val * inv_row_sum
|
||||
sC[(row, col), Int32(0), Int32(0)] = val.to(self.o_dtype)
|
||||
sC_flat[row, col] = val.to(self.o_dtype)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# Step 2: TMA store sC -> GMEM (CUTLASS FMHA reference pattern)
|
||||
# Create gO from mC via flat_divide + slice (same as CUTLASS reference)
|
||||
# Step 2: TMA store sC_flat -> GMEM
|
||||
gO_qdl = cute.flat_divide(
|
||||
mC, cute.select(self.pv_mma_tiler, mode=[0, 1])
|
||||
)
|
||||
gO = gO_qdl[None, None, None, Int32(0), Int32(0)]
|
||||
tOsO, tOgO = cpasync.tma_partition(
|
||||
tOsC, tOgO = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(sC_flat, 0, 2),
|
||||
cute.group_modes(gO, 0, 2),
|
||||
)
|
||||
# Wait for all epilogue warps to finish writing to sC
|
||||
epilog_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, tOsO[None, Int32(0)], tOgO[None, Int32(0)])
|
||||
cute.copy(tma_c, tOsC[None, Int32(0)], tOgO[None, Int32(0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user