From 9c331de7ba263ffa1ecc3fc86eddb6ed79d823cf Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 02:54:54 +0000 Subject: [PATCH] fix: revert to composition layout for hand-constructed atoms (matching CUTLASS) --- tests/unit/test_fmha_v3_stage_c.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 2ba1c95a..c4a52a8f 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -234,12 +234,13 @@ class FmhaV3StageCMulti: row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) - # O rescale atoms (hand-constructed, for per-tile O *= acc_scale) - # Use logical_divide (not composition) to match get_tmem_load_op layout + # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) corr_tile_size = 16 tOcO = pv_thr.partition_C(cS) - tOtO_epi_r = cute.logical_divide(tOtO0, cute.make_layout((128, corr_tile_size))) - tOcO_epi_r = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) + tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) tmem_load_o_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype, @@ -248,13 +249,13 @@ class FmhaV3StageCMulti: tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype, ) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_epi_r[(None, None), 0]) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_epi_r[(None, None), 0]) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_epi_r[(None, None), None]) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_epi_r[(None, None), None]) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_epi_r[(None, None), None]) + tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) n_corr_tiles = HEAD_DIM // corr_tile_size for kt in range(self.n_kv_tiles):