diff --git a/tests/unit/test_fmha_v3_diag.py b/tests/unit/test_fmha_v3_diag.py index e50cd1dc..98d580c8 100644 --- a/tests/unit/test_fmha_v3_diag.py +++ b/tests/unit/test_fmha_v3_diag.py @@ -137,7 +137,9 @@ class FmhaV3Diag: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,None,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,0,None,0)] + tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)] + tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)] + tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -163,12 +165,12 @@ class FmhaV3Diag: # TMA LOAD (combined K+V barrier) if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, qh.count, None)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() for kt in cutlass.range(n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kvh.count, None)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail() diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 97ef0092..baf59749 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -171,10 +171,11 @@ class FmhaV3StageC: # K: GMEM iter at mode 1 (from QK MMA B-partition, layout K,D,L) # V: GMEM iter at mode 2 (from PV MMA B-partition, layout D,K,L) # Must keep GMEM iter dimension FREE for multi-tile KV loads. - # CUTLASS reference: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0] - tAgQ = tAgQ[(None,None,None,0)] # Q: keep mode 1 free (only 1 tile, doesn't matter) - tBgK = tBgK[(None,None,None,0)] # K: keep mode 1 free for kvh.count - tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 free for kvh.count + # CUTLASS reference slices: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0] + # Both GMEM and SMEM sides must be sliced consistently for cute.copy rank matching. + tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile only + tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) + tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -202,16 +203,16 @@ class FmhaV3StageC: # One acquire per kt; K and V both target kvh.barrier. kvh.count == kt. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, qh.count, None)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() for kt in cutlass.range(n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) # Both transfers decrement the same barrier's tx_count. # kvh.count is a pipeline-state Int32 (the form cute.copy accepts). - # K: GMEM iter at mode 1, so index (None, kvh.count, None) - # V: GMEM iter at mode 2, so index (None, kvh.count) - cute.copy(tma_k, tBgK[(None, kvh.count, None)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + # K: mode 1 is GMEM iter → tBgK[(None, kvh.count)] + # V: mode 2 is GMEM iter → tVgV[(None, kvh.count)] + cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()