diff --git a/tests/fmha_v3_stage_c_example7.py b/tests/fmha_v3_stage_c_example7.py index dac7347d..6e5ae19b 100644 --- a/tests/fmha_v3_stage_c_example7.py +++ b/tests/fmha_v3_stage_c_example7.py @@ -203,22 +203,19 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # GMEM tile coordinate: use the cutlass.range induction variable kt - # directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=` - # rebinding as a loop-carried iter_args update — the JIT traces the - # body once and captures whatever value `kv_coord` had at trace time, - # so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the - # loop bakes 0 into every iteration's TMA descriptor at runtime. - # The induction variable IS the loop-carried state, properly tracked. + # Use the SAME pattern as the working test_fmha_v3_diag.py: + # kv_coord = Int32(0+0) + kv_coord += 1 in cutlass.range if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): + kv_coord = Int32(0 + 0) + for kt in cutlass.range(self.s_k // 128, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + kv_coord += 1 pk = cutlass.Boolean(1) kvp.tail() @@ -292,16 +289,34 @@ class FmhaV3StageCMulti: # Per-tile softmax loop. # Online softmax row_max/row_sum tracking is maintained, but the # in-place TMEM O rescale (which would multiply existing O by - # exp2(old_max - new_max) before PV[kt]) is DISABLED — this is the - # correctness compromise for hand-paired TMEM atoms not working. - # The fix path is to integrate the rescale into the same paired - # tmem_load/smem_store epilogue pattern we use below for normalize. - # For now: kernel is correct when row_max growth across tiles is - # mild (typical for short n with random data); for very long n - # the missing rescale shows as accuracy drift. + # O rescale setup: same correction_rescale pattern as the final normalize + corr_tile_size_rs = 16 + cO_rs = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO_rs = pv_thr.partition_C(cO_rs) + tOtO_rs_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size_rs))) + tOcO_rs_i_layout = cute.composition(tOcO_rs.layout, cute.make_layout((128, corr_tile_size_rs))) + tOtO_rs_i = cute.make_tensor(tOtO0.iterator, tOtO_rs_i_layout) + tOcO_rs_i = cute.make_tensor(tOcO_rs.iterator, tOcO_rs_i_layout) + tmem_load_o_rs_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype) + tmem_store_o_rs_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype) + tiled_tmem_load_o_rs = tcgen05.make_tmem_copy(tmem_load_o_rs_atom, tOtO_rs_i) + tiled_tmem_store_o_rs = tcgen05.make_tmem_copy(tmem_store_o_rs_atom, tOtO_rs_i) + thr_tmem_load_o_rs = tiled_tmem_load_o_rs.get_slice(sfw_idx) + thr_tmem_store_o_rs = tiled_tmem_store_o_rs.get_slice(sfw_idx) + tTMEM_LOAD_OtO_rs = thr_tmem_load_o_rs.partition_S(tOtO_rs_i) + tTMEM_LOAD_OcO_rs = thr_tmem_load_o_rs.partition_D(tOcO_rs_i) + tTMEM_STORE_OtO_rs = thr_tmem_store_o_rs.partition_D(tOtO_rs_i) + tTMrO_rs = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO_rs.shape, 128 // corr_tile_size_rs), self.acc_dtype) + + row_max = -Float32.inf + row_sum = Float32(0.0) + scale_log2 = Float32(self.scale_softmax_log2) + for kt in range(n_kv_tiles): si_handle = s_cons.wait_and_advance() - # Load S[kt] tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) @@ -329,6 +344,27 @@ class FmhaV3StageCMulti: acc_scale = Float32(0.0) row_sum *= acc_scale + # O rescale: multiply existing O by acc_scale = exp2(old_max - new_max) + # Uses the same correction_rescale pattern verified for final normalize. + if kt > 0: + for ci in range(HEAD_DIM // corr_tile_size_rs): + tTMrO_rs_i_ = tTMrO_rs[None, ci] + tTMrO_rs_i_layout = cute.composition( + tTMrO_rs_i_.layout, cute.make_layout(tTMrO_rs.shape[0]) + ) + tTMrO_rs_i = cute.make_tensor(tTMrO_rs_i_.iterator, tTMrO_rs_i_layout) + tTMEM_LOAD_OtO_rs_i = cute.make_tensor( + tTMEM_LOAD_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_LOAD_OtO_rs.layout + ) + tTMEM_STORE_OtO_rs_i = cute.make_tensor( + tTMEM_STORE_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_STORE_OtO_rs.layout + ) + cute.copy(tiled_tmem_load_o_rs, tTMEM_LOAD_OtO_rs_i, tTMrO_rs_i) + for j in cutlass.range(cute.size(tTMrO_rs_i), vectorize=True): + tTMrO_rs_i[j] = tTMrO_rs_i[j] * acc_scale + cute.copy(tiled_tmem_store_o_rs, tTMrO_rs_i, tTMEM_STORE_OtO_rs_i) + cute.arch.fence_view_async_tmem_store() + # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, # store BF16 P through the FP32-backed register bridge. rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)