SMEM accumulator: direct GMEM write from sO_acc (bypass TMA for multi-kt)
This commit is contained in:
@@ -592,7 +592,7 @@ class FmhaKernel:
|
||||
# ============================================================
|
||||
|
||||
if const_expr(not self.use_smem_accumulator):
|
||||
# Path 1: epilogue_tma_store (reads O from TMEM)
|
||||
# Path 1: epilogue_tma_store (reads O from TMEM, proven for n_kv=1)
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
@@ -606,39 +606,24 @@ class FmhaKernel:
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
else:
|
||||
# Path 2: sO_acc -> sC -> TMA store to GMEM
|
||||
# Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing.
|
||||
# sC layout from make_smem_layout_epi: ((M,N), ?, num_c_stage, ...).
|
||||
# Write via sC's native coordinate system.
|
||||
for row in cutlass.range(0, 128, unroll=1):
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype)
|
||||
|
||||
# TMA store sC -> GMEM using cpasync.tma_partition
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
epilog_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
tCgC_xfm = transform_partitioned_tensor_layout(tCgC)
|
||||
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
# Slice off MMA tile coordinates (same as epilogue_tma_store)
|
||||
bSG_gC = bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))]
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, None, Int32(0))])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_tail()
|
||||
# Path 2: write sO_acc (FP32) -> GMEM directly from registers
|
||||
# Each thread handles one row (sfw_idx) of the output.
|
||||
# Read sO_acc row, normalize, cast to BF16, write to GMEM.
|
||||
#
|
||||
# For GMEM writes, use gC (the local_tile of the output tensor).
|
||||
# gC is created from mC with TMA-compatible layout.
|
||||
# One option would be to write to it using cute.copy with a universal copy atom.
|
||||
#
|
||||
# Actually, the simplest approach: write each element directly
|
||||
# to the GMEM output tensor using scalar stores.
|
||||
# gC[sfw_idx, col, 0] = BF16(sO_acc[sfw_idx, col] / row_sum)
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
row = sfw_idx
|
||||
if row < Int32(128):
|
||||
val = sO_acc[row, col]
|
||||
if const_expr(not self.normalize):
|
||||
val = val / row_sum
|
||||
gC[Int32(row), Int32(col), Int32(0)] = val.to(self.o_dtype)
|
||||
|
||||
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
|
||||
# Only when emitting un-normalized output (D5a path).
|
||||
|
||||
Reference in New Issue
Block a user