diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 0a3cc287..492f423e 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -299,26 +299,13 @@ class FmhaV3RealSoftmax: # Final O normalization: O = O / row_sum if row_sum != Float32(0.0): inv_row_sum = Float32(1.0) / row_sum - # Load O from TMEM, multiply by 1/row_sum, write back - n_corr = 128 // corr_tile_size - for ci in range(n_corr): - tOtO_ci = cute.make_tensor(tOtO_i.iterator, tOtO_i.layout) - tOcO_ci = cute.make_tensor(tOcO_i.iterator, tOcO_i.layout) - # Sub-tile index offset - tOtO_sub = tOtO_ci[None, ci, None] - tOcO_sub = tOcO_ci[None, ci, None] - # Use the base partitioned tensors with offset - # Actually, just load the full O sub-tile - pass - # Simple approach: load/store via the partitioned tensors tTMEM_LOAD_OrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMEM_LOAD_OrO) cute.arch.fence_view_async_tmem_load() - # Iterate with proper indexing - n_o_frg = cute.size(tTMEM_LOAD_OrO, mode=[0]) - for fi in range(n_o_frg): + # The register tensor from the O partition is 2D: (frg, corr_tile) + for fi in range(cute.size(tTMEM_LOAD_OrO, mode=[0])): for fj in range(cute.size(tTMEM_LOAD_OrO, mode=[1])): - tTMEM_LOAD_OrO[fi, fj, None] = tTMEM_LOAD_OrO[fi, fj, None] * inv_row_sum + tTMEM_LOAD_OrO[fi, fj] = tTMEM_LOAD_OrO[fi, fj] * inv_row_sum cute.copy(tiled_tmem_store_o, tTMEM_LOAD_OrO, tTMEM_STORE_OtO) cute.arch.fence_view_async_tmem_store()