diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index c738df23..e8dfc384 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -406,22 +406,30 @@ class FmhaV3StageCMulti: final_o_bar.arrive_and_wait() # === Final O normalization: O *= 1/row_sum === - # Uses the same paired atoms and sub-tile pattern as the per-tile rescale. + # Uses the working 2D register tensor pattern from working_softmax_maybe.py. inv_row_sum = Float32(1.0) / row_sum + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(HEAD_DIM // corr_tile_size): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) tTMEM_LOADtO_i = cute.make_tensor( tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout ) tTMEM_STOREtO_i = cute.make_tensor( tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout ) - tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) - cute.arch.fence_view_async_tmem_load() - for k in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[k] = tTMrO[k] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) + + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store()