From ba2cefb6683211bddfa8175e1b185db1a58fce14 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 17:26:56 +0000 Subject: [PATCH] Stage C: manual kv_coord + correct K GMEM slice + O rescale fence Key fixes: 1. GMEM tile coord: manual Int32 kv_coord (not kvh.count) 2. K GMEM slice: (None,None,0,0) keeps mode 1 free (GMEM iter) 3. V GMEM slice: (None,0,None,0) keeps mode 2 free (GMEM iter) 4. Add fence_view_async_tmem_load before O rescale for visibility --- tests/unit/test_fmha_v3_stage_c_full.py | 34 ++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 88b63547..13f7b69a 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -167,13 +167,13 @@ class FmhaV3StageC: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # Only slice GMEM tensors (SMEM tensors from tma_partition are already 2D). - # K from QK MMA: GMEM iter at mode 1 → slice [(None,None,0,0)] keeps modes 0,1 - # V from PV MMA: GMEM iter at mode 2 → slice [(None,0,None,0)] keeps modes 0,2 - # Q: 1 tile only, original slice is fine. - tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode GMEM iter to 0 - tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free for kvh.count - tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free for kvh.count + # GMEM slices: keep the GMEM iteration mode free for kv_coord indexing. + # CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter) + # tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2) + # SMEM tensors from tma_partition are already 2D — don't re-slice them. + tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine + tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free + tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -198,20 +198,19 @@ class FmhaV3StageC: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # One acquire per kt; K and V both target kvh.barrier. kvh.count == kt. + # GMEM tile coordinate: manual Int32 counter (kv_coord). SMEM slot: kvh.index. + # Pipeline handle .count is NOT a usable GMEM coordinate. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() + kv_coord = Int32(0) for kt in cutlass.range(n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) - # Both transfers decrement the same barrier's tx_count. - # kvh.count is a pipeline-state Int32 (the form cute.copy accepts). - # K: mode 1 is GMEM iter → tBgK[(None, kvh.count)] - # V: mode 2 is GMEM iter → tVgV[(None, kvh.count)] - cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + kv_coord += 1 pk = cutlass.Boolean(1) kvp.tail() @@ -351,9 +350,10 @@ class FmhaV3StageC: cute.arch.fence_view_async_tmem_store() # O rescale for kt > 0. Reads O written by MMA's PV[kt-1]; - # visibility is provided by s_cons.wait_and_advance above - # (acquires on MMA's S[kt] commit, which orders PV[kt-1] before). + # s_cons.wait_and_advance acquires on MMA's S[kt] commit (sequenced + # after PV[kt-1]), but add an explicit tmem_load fence for safety. if kt > 0: + cute.arch.fence_view_async_tmem_load() for i in range(o_col_tiles): tTMEM_LOAD_O_i = cute.make_tensor( tTMEM_LOAD_OtO.iterator + i * corr_tile_size,