diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index f8549eff..12264d5e 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -179,12 +179,11 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # TMA source tensor slices: keep the GMEM tile dimension (mode 4) free - # tBgK shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles - # tVgV shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles + # After tma_partition, tBgK/tVgV have 8 modes. + # Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles). + # Do NOT pre-slice — index all 8 modes explicitly in cute.copy. + # tAgQ is fine with 4-mode slice (Q has only 1 tile). tAgQ = tAgQ[(None,0,None,0)] - tBgK = tBgK[(None,None,None,None,None,None,None,None)] # No-op, use full indexing in copy - tVgV = tVgV[(None,None,None,None,None,None,None,None)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -222,8 +221,8 @@ class FmhaV3StageCMulti: kvp.reset(); pk = kvp.try_acquire() for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()