Fix epilogue: corr_tile_size=16, proper epi_subtile tuple, match CUTLASS reference

This commit is contained in:
2026-05-22 19:47:57 +00:00
parent 8b93774d70
commit afe5d1ae21

View File

@@ -369,42 +369,22 @@ class FmhaV3StageCMulti:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# Build paired TMEM load + SMEM store atoms via sm100_utils
# (the same path the reference uses). These two atoms share the
# same register-tile shape, so reg data round-trips losslessly.
#
# epi_subtile narrows the C-tile to a column slab matching one
# SMEM stage; smaller subtiles = more iterations but lower
# register pressure. Use the kernel's pre-computed self.epi_tile
# (computed from compute_epilogue_tile_shape in _setup).
epi_subtile = self.epi_tile
# === Reference-style scaled epilogue (mirrors CUTLASS FMHA correction_epilog) ===
# Pattern: TMEM → reg (paired load atom) → scale in reg → FP32→BF16 in reg
# → SMEM (paired store atom) → TMA SMEM→GMEM. No TMEM round-trip.
# The TMEM tensor we load from is the full O accumulator at the
# PV offset; partition it column-wise into epi_subtile chunks.
tOtO_epi = cute.logical_divide(
tOtO0,
cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])),
)
# Corresponding identity coord tensor (for partitioning).
cO_full = cute.make_identity_tensor(
(self.pv_mma_tiler[0], self.pv_mma_tiler[1])
)
tOcO_full = pv_thr.partition_C(cO_full)
tOcO_epi = cute.logical_divide(
tOcO_full,
cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])),
)
# And the SMEM destination tensor (stage 0 of sC).
corr_tile_size = 16 # matches the reference
# Sub-tile the O C-fragment for column-wise iteration
tOtO_i = cute.logical_divide(tOtO0, cute.make_layout((128, corr_tile_size)))
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size)))
tOsO = pv_thr.partition_C(sC[None, None, 0])
tOsO_epi = cute.logical_divide(
tOsO,
cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])),
)
tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size)))
# Paired atoms via sm100_utils. get_tmem_load_op returns an atom
# configured for the (mma_tiler, dtype, epi_subtile) combo; the
# matching smem_store atom is derived from the resulting tiled
# tmem load so that the register tile shape lines up.
# Paired atoms via sm100_utils (same as CUTLASS reference)
epi_subtile = (self.epi_tile[0], corr_tile_size)
tmem_load_op = utils.sm100.get_tmem_load_op(
self.pv_mma_tiler,
self.c_layout,
@@ -414,7 +394,7 @@ class FmhaV3StageCMulti:
use_2cta_instrs=False,
)
tiled_tmem_load_o = tcgen05.make_tmem_copy(
tmem_load_op, tOtO_epi[(None, None), 0]
tmem_load_op, tOtO_i[(None, None), 0]
)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
smem_store_op = utils.sm100.get_smem_store_op(
@@ -424,34 +404,33 @@ class FmhaV3StageCMulti:
smem_store_op, tiled_tmem_load_o
)
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_epi[(None, None), None])
tTMEM_LOADsO = thr_tmem_load_o.partition_D(tOsO_epi[(None, None), None])
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_epi[(None, None), None])
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i[(None, None), None])
tTMEM_LOADsO = thr_tmem_load_o.partition_D(tOsO_i[(None, None), None])
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i[(None, None), None])
# Scale = 1/row_sum (fused into the TMEM->SMEM pass).
# Scale = 1/row_sum
inv_row_sum = Float32(1.0) / row_sum
n_epi_tiles = self.pv_mma_tiler[1] // epi_subtile[1]
for i in range(n_epi_tiles):
# Load this column sub-tile from TMEM into registers.
n_corr = self.pv_mma_tiler[1] // corr_tile_size
for i in range(n_corr):
tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i]
tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i]
tTMrO = cute.make_rmem_tensor(
tTMEM_LOADcO[None, 0, 0, i].shape, self.acc_dtype
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO[None, 0, 0, i], tTMrO)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
# Apply normalization in FP32 registers, before dtype conv.
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * inv_row_sum
# Scale in FP32 registers
for j in range(cute.size(tTMrO), vectorize=True):
tTMrO[j] = tTMrO[j] * inv_row_sum
# Convert FP32 -> output dtype (BF16) in registers.
# FP32 → BF16 in registers
tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype)
o_vec = tTMrO.load()
tSMrO.store(o_vec.to(self.o_dtype))
# Store registers -> SMEM via paired atom.
cute.copy(
tiled_smem_store_o, tSMrO, tTMEM_LOADsO[None, 0, 0, i]
)
# Registers SMEM via paired atom
cute.copy(tiled_smem_store_o, tSMrO, tTMEM_LOADsO_i)
cute.arch.fence_view_async_tmem_load()
# Async-proxy fence so the TMA store sees the SMEM writes.