Add SMEM accumulator for n_kv_tiles>1: O load from TMEM, accumulate in sO_acc, TMA store from sC
This commit is contained in:
@@ -66,6 +66,8 @@ class FmhaKernel:
|
||||
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
# SMEM accumulator: needed when n_kv_tiles > 1 (TMEM round-trip is broken)
|
||||
self.use_smem_accumulator = self.n_kv_tiles > 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
@@ -217,6 +219,14 @@ class FmhaKernel:
|
||||
_p_swizzle = cute.make_layout(((1,1),1,(1,1),1))
|
||||
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
|
||||
|
||||
# SMEM accumulator: FP32 [128, pv_n_tile] row-major
|
||||
# Only needed for n_kv_tiles > 1 (TMEM round-trip is broken)
|
||||
if const_expr(self.use_smem_accumulator):
|
||||
sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1))
|
||||
sO_acc = smem.allocate_tensor(element_type=Float32, layout=sO_acc_layout, byte_alignment=128)
|
||||
else:
|
||||
sO_acc = smem.allocate_tensor(element_type=Float32, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128)
|
||||
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
|
||||
@@ -342,7 +352,7 @@ class FmhaKernel:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0 and not self.use_smem_accumulator)
|
||||
if not self.use_smem_p:
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
@@ -405,6 +415,21 @@ class FmhaKernel:
|
||||
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
|
||||
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
|
||||
|
||||
# O load atoms: one-way TMEM->REGS read of O_kt after PV (for SMEM accumulator)
|
||||
# Only needed when use_smem_accumulator (n_kv_tiles > 1)
|
||||
if const_expr(self.use_smem_accumulator):
|
||||
o_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_o_load = tcgen05.make_tmem_copy(o_load_atom, tOtO0)
|
||||
thr_o_load = tiled_o_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtO = thr_o_load.partition_S(tOtO0)
|
||||
cO = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tTMEM_LOADcO = thr_o_load.partition_D(tOcO)
|
||||
|
||||
# Zero-initialize sO_acc
|
||||
for i in cutlass.range(0, cute.size(sO_acc), unroll=1):
|
||||
sO_acc[i] = Float32(0.0)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
@@ -539,37 +564,79 @@ class FmhaKernel:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# --- SMEM accumulator: load O_kt from TMEM, accumulate in sO_acc ---
|
||||
if const_expr(self.use_smem_accumulator):
|
||||
pv_done_bar.arrive_and_wait()
|
||||
rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_o_load, tTMEM_LOADtO, rO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# sO_acc = acc_scale * sO_acc + O_kt
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcO[(j0, 0), j1, 0, 0]
|
||||
row = coord[0]
|
||||
col = coord[1]
|
||||
old_val = sO_acc[row, col]
|
||||
new_val = acc_scale * old_val + rO[(j0, 0), j1, 0, 0]
|
||||
sO_acc[row, col] = new_val
|
||||
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# ============================================================
|
||||
# EPILOGUE: TMA store O to GMEM + compute LSE
|
||||
# EPILOGUE: store O to GMEM + compute LSE
|
||||
# ============================================================
|
||||
# The raw un-normalized O in TMEM is perfect (cos 0.999998).
|
||||
# We use epilogue_tma_store which reads O from TMEM directly via
|
||||
# the correct get_tmem_load_op layout — no round-trip needed.
|
||||
#
|
||||
# For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures
|
||||
# O is correctly rescaled before this epilogue reads it.
|
||||
#
|
||||
# External normalization (D5a path): kernel outputs un-normalized O +
|
||||
# LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum.
|
||||
# This is exact and composes with D5c sink bias merge.
|
||||
# For n_kv_tiles=1: O is in TMEM (from PV). Use epilogue_tma_store.
|
||||
# For n_kv_tiles>1: O is in sO_acc (SMEM). Write to sC, then TMA store.
|
||||
# ============================================================
|
||||
|
||||
# TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM)
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
if const_expr(not self.use_smem_accumulator):
|
||||
# Path 1: epilogue_tma_store (reads O from TMEM)
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
else:
|
||||
# Path 2: sO_acc -> sC -> TMA store to GMEM
|
||||
# Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing.
|
||||
# sC layout from make_smem_layout_epi: ((M,N), ?, num_c_stage, ...).
|
||||
# Write via sC's native coordinate system.
|
||||
for row in cutlass.range(0, 128, unroll=1):
|
||||
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
|
||||
sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype)
|
||||
|
||||
# TMA store sC -> GMEM using cpasync.tma_partition
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
epilog_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
tCgC_xfm = transform_partitioned_tensor_layout(tCgC)
|
||||
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_c_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC[(Int32(0),)], bSG_gC[(Int32(0),)])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_tail()
|
||||
|
||||
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
|
||||
# Only when emitting un-normalized output (D5a path).
|
||||
|
||||
Reference in New Issue
Block a user