Fix epilogue: corr_tile_size=16, proper epi_subtile tuple, match CUTLASS reference
This commit is contained in:
@@ -369,42 +369,22 @@ class FmhaV3StageCMulti:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# Build paired TMEM load + SMEM store atoms via sm100_utils
|
||||
# (the same path the reference uses). These two atoms share the
|
||||
# same register-tile shape, so reg data round-trips losslessly.
|
||||
#
|
||||
# epi_subtile narrows the C-tile to a column slab matching one
|
||||
# SMEM stage; smaller subtiles = more iterations but lower
|
||||
# register pressure. Use the kernel's pre-computed self.epi_tile
|
||||
# (computed from compute_epilogue_tile_shape in _setup).
|
||||
epi_subtile = self.epi_tile
|
||||
# === Reference-style scaled epilogue (mirrors CUTLASS FMHA correction_epilog) ===
|
||||
# Pattern: TMEM → reg (paired load atom) → scale in reg → FP32→BF16 in reg
|
||||
# → SMEM (paired store atom) → TMA SMEM→GMEM. No TMEM round-trip.
|
||||
|
||||
# The TMEM tensor we load from is the full O accumulator at the
|
||||
# PV offset; partition it column-wise into epi_subtile chunks.
|
||||
tOtO_epi = cute.logical_divide(
|
||||
tOtO0,
|
||||
cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])),
|
||||
)
|
||||
# Corresponding identity coord tensor (for partitioning).
|
||||
cO_full = cute.make_identity_tensor(
|
||||
(self.pv_mma_tiler[0], self.pv_mma_tiler[1])
|
||||
)
|
||||
tOcO_full = pv_thr.partition_C(cO_full)
|
||||
tOcO_epi = cute.logical_divide(
|
||||
tOcO_full,
|
||||
cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])),
|
||||
)
|
||||
# And the SMEM destination tensor (stage 0 of sC).
|
||||
corr_tile_size = 16 # matches the reference
|
||||
|
||||
# Sub-tile the O C-fragment for column-wise iteration
|
||||
tOtO_i = cute.logical_divide(tOtO0, cute.make_layout((128, corr_tile_size)))
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size)))
|
||||
tOsO = pv_thr.partition_C(sC[None, None, 0])
|
||||
tOsO_epi = cute.logical_divide(
|
||||
tOsO,
|
||||
cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])),
|
||||
)
|
||||
tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size)))
|
||||
|
||||
# Paired atoms via sm100_utils. get_tmem_load_op returns an atom
|
||||
# configured for the (mma_tiler, dtype, epi_subtile) combo; the
|
||||
# matching smem_store atom is derived from the resulting tiled
|
||||
# tmem load so that the register tile shape lines up.
|
||||
# Paired atoms via sm100_utils (same as CUTLASS reference)
|
||||
epi_subtile = (self.epi_tile[0], corr_tile_size)
|
||||
tmem_load_op = utils.sm100.get_tmem_load_op(
|
||||
self.pv_mma_tiler,
|
||||
self.c_layout,
|
||||
@@ -414,7 +394,7 @@ class FmhaV3StageCMulti:
|
||||
use_2cta_instrs=False,
|
||||
)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(
|
||||
tmem_load_op, tOtO_epi[(None, None), 0]
|
||||
tmem_load_op, tOtO_i[(None, None), 0]
|
||||
)
|
||||
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
smem_store_op = utils.sm100.get_smem_store_op(
|
||||
@@ -424,34 +404,33 @@ class FmhaV3StageCMulti:
|
||||
smem_store_op, tiled_tmem_load_o
|
||||
)
|
||||
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_epi[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load_o.partition_D(tOsO_epi[(None, None), None])
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_epi[(None, None), None])
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load_o.partition_D(tOsO_i[(None, None), None])
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i[(None, None), None])
|
||||
|
||||
# Scale = 1/row_sum (fused into the TMEM->SMEM pass).
|
||||
# Scale = 1/row_sum
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
n_epi_tiles = self.pv_mma_tiler[1] // epi_subtile[1]
|
||||
for i in range(n_epi_tiles):
|
||||
# Load this column sub-tile from TMEM into registers.
|
||||
n_corr = self.pv_mma_tiler[1] // corr_tile_size
|
||||
for i in range(n_corr):
|
||||
tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i]
|
||||
tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i]
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
tTMEM_LOADcO[None, 0, 0, i].shape, self.acc_dtype
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO[None, 0, 0, i], tTMrO)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
|
||||
|
||||
# Apply normalization in FP32 registers, before dtype conv.
|
||||
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[k] = tTMrO[k] * inv_row_sum
|
||||
# Scale in FP32 registers
|
||||
for j in range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[j] = tTMrO[j] * inv_row_sum
|
||||
|
||||
# Convert FP32 -> output dtype (BF16) in registers.
|
||||
# FP32 → BF16 in registers
|
||||
tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype)
|
||||
o_vec = tTMrO.load()
|
||||
tSMrO.store(o_vec.to(self.o_dtype))
|
||||
|
||||
# Store registers -> SMEM via paired atom.
|
||||
cute.copy(
|
||||
tiled_smem_store_o, tSMrO, tTMEM_LOADsO[None, 0, 0, i]
|
||||
)
|
||||
# Registers → SMEM via paired atom
|
||||
cute.copy(tiled_smem_store_o, tSMrO, tTMEM_LOADsO_i)
|
||||
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# Async-proxy fence so the TMA store sees the SMEM writes.
|
||||
|
||||
Reference in New Issue
Block a user