From afe5d1ae21079001bf1e2a3fbb0b7c30e47f23ab Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 19:47:57 +0000 Subject: [PATCH] Fix epilogue: corr_tile_size=16, proper epi_subtile tuple, match CUTLASS reference --- tests/fmha_v3_stage_c_example7.py | 79 ++++++++++++------------------- 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/tests/fmha_v3_stage_c_example7.py b/tests/fmha_v3_stage_c_example7.py index aa6ad2e1..374e4124 100644 --- a/tests/fmha_v3_stage_c_example7.py +++ b/tests/fmha_v3_stage_c_example7.py @@ -369,42 +369,22 @@ class FmhaV3StageCMulti: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # Build paired TMEM load + SMEM store atoms via sm100_utils - # (the same path the reference uses). These two atoms share the - # same register-tile shape, so reg data round-trips losslessly. - # - # epi_subtile narrows the C-tile to a column slab matching one - # SMEM stage; smaller subtiles = more iterations but lower - # register pressure. Use the kernel's pre-computed self.epi_tile - # (computed from compute_epilogue_tile_shape in _setup). - epi_subtile = self.epi_tile + # === Reference-style scaled epilogue (mirrors CUTLASS FMHA correction_epilog) === + # Pattern: TMEM → reg (paired load atom) → scale in reg → FP32→BF16 in reg + # → SMEM (paired store atom) → TMA SMEM→GMEM. No TMEM round-trip. - # The TMEM tensor we load from is the full O accumulator at the - # PV offset; partition it column-wise into epi_subtile chunks. - tOtO_epi = cute.logical_divide( - tOtO0, - cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])), - ) - # Corresponding identity coord tensor (for partitioning). - cO_full = cute.make_identity_tensor( - (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) - ) - tOcO_full = pv_thr.partition_C(cO_full) - tOcO_epi = cute.logical_divide( - tOcO_full, - cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])), - ) - # And the SMEM destination tensor (stage 0 of sC). + corr_tile_size = 16 # matches the reference + + # Sub-tile the O C-fragment for column-wise iteration + tOtO_i = cute.logical_divide(tOtO0, cute.make_layout((128, corr_tile_size))) + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) tOsO = pv_thr.partition_C(sC[None, None, 0]) - tOsO_epi = cute.logical_divide( - tOsO, - cute.make_layout((self.pv_mma_tiler[0], epi_subtile[1])), - ) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) - # Paired atoms via sm100_utils. get_tmem_load_op returns an atom - # configured for the (mma_tiler, dtype, epi_subtile) combo; the - # matching smem_store atom is derived from the resulting tiled - # tmem load so that the register tile shape lines up. + # Paired atoms via sm100_utils (same as CUTLASS reference) + epi_subtile = (self.epi_tile[0], corr_tile_size) tmem_load_op = utils.sm100.get_tmem_load_op( self.pv_mma_tiler, self.c_layout, @@ -414,7 +394,7 @@ class FmhaV3StageCMulti: use_2cta_instrs=False, ) tiled_tmem_load_o = tcgen05.make_tmem_copy( - tmem_load_op, tOtO_epi[(None, None), 0] + tmem_load_op, tOtO_i[(None, None), 0] ) thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) smem_store_op = utils.sm100.get_smem_store_op( @@ -424,34 +404,33 @@ class FmhaV3StageCMulti: smem_store_op, tiled_tmem_load_o ) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_epi[(None, None), None]) - tTMEM_LOADsO = thr_tmem_load_o.partition_D(tOsO_epi[(None, None), None]) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_epi[(None, None), None]) + tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i[(None, None), None]) + tTMEM_LOADsO = thr_tmem_load_o.partition_D(tOsO_i[(None, None), None]) + tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i[(None, None), None]) - # Scale = 1/row_sum (fused into the TMEM->SMEM pass). + # Scale = 1/row_sum inv_row_sum = Float32(1.0) / row_sum - n_epi_tiles = self.pv_mma_tiler[1] // epi_subtile[1] - for i in range(n_epi_tiles): - # Load this column sub-tile from TMEM into registers. + n_corr = self.pv_mma_tiler[1] // corr_tile_size + for i in range(n_corr): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i] + tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] tTMrO = cute.make_rmem_tensor( tTMEM_LOADcO[None, 0, 0, i].shape, self.acc_dtype ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO[None, 0, 0, i], tTMrO) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) - # Apply normalization in FP32 registers, before dtype conv. - for k in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[k] = tTMrO[k] * inv_row_sum + # Scale in FP32 registers + for j in range(cute.size(tTMrO), vectorize=True): + tTMrO[j] = tTMrO[j] * inv_row_sum - # Convert FP32 -> output dtype (BF16) in registers. + # FP32 → BF16 in registers tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype) o_vec = tTMrO.load() tSMrO.store(o_vec.to(self.o_dtype)) - # Store registers -> SMEM via paired atom. - cute.copy( - tiled_smem_store_o, tSMrO, tTMEM_LOADsO[None, 0, 0, i] - ) + # Registers → SMEM via paired atom + cute.copy(tiled_smem_store_o, tSMrO, tTMEM_LOADsO_i) cute.arch.fence_view_async_tmem_load() # Async-proxy fence so the TMA store sees the SMEM writes.