From f4b7bed463bdf2d5e8bc24c12db4c0d43f210dd4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 00:25:16 +0000 Subject: [PATCH] Fix final normalize: use working 2D register tensor pattern from working_softmax_maybe.py The make_rmem_tensor(tTMEM_LOADcO.shape) creates a 1D tensor that doesn't match the paired atom layout. The working pattern uses a 2D register tensor with sub-tile composition (tTMrO_i_ = tTMrO[None, i] + composition). --- tests/unit/test_fmha_v3_stage_c.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index c738df23..e8dfc384 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -406,22 +406,30 @@ class FmhaV3StageCMulti: final_o_bar.arrive_and_wait() # === Final O normalization: O *= 1/row_sum === - # Uses the same paired atoms and sub-tile pattern as the per-tile rescale. + # Uses the working 2D register tensor pattern from working_softmax_maybe.py. inv_row_sum = Float32(1.0) / row_sum + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(HEAD_DIM // corr_tile_size): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) tTMEM_LOADtO_i = cute.make_tensor( tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout ) tTMEM_STOREtO_i = cute.make_tensor( tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout ) - tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) - cute.arch.fence_view_async_tmem_load() - for k in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[k] = tTMrO[k] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) + + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store()