diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 5e2067c1..961c46c6 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -369,39 +369,9 @@ class FmhaKernel: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === O normalization: TMEM -> reg (scale by 1/row_sum) -> TMEM === - # Uses hand-constructed Ld32x32bOp/St32x32bOp atoms (same as correction_rescale). - # The layout mismatch in these atoms introduces ~3% error per round-trip, - # but the correction_rescale atoms (same construction) already use this path. - # TODO: Replace with get_tmem_load_op-derived atoms for zero error. - inv_row_sum = Float32(1.0) / row_sum - - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() - - # Epilogue: TMEM → SMEM → GMEM via TMA store. - # Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) internally. - # Since O is already normalized in TMEM, we apply identity epilogue_op. + # === DIAGNOSTIC: Test epilogue_tma_store WITHOUT any round-trips === + # If get_tmem_load_op reads O correctly from TMEM, this should give cos 0.9999 + # (un-normalized, just raw PV sum). Then we can add normalization back. tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) acc_cons_st = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage