diff --git a/tests/fmha_v3_stage_c_example1.py b/tests/fmha_v3_stage_c_example1.py index 9296974e..f0811995 100644 --- a/tests/fmha_v3_stage_c_example1.py +++ b/tests/fmha_v3_stage_c_example1.py @@ -177,7 +177,7 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,None,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -214,11 +214,11 @@ class FmhaV3StageCMulti: for kt in cutlass.range(n_kv_tiles, unroll=1): kh = kvp.acquire_and_advance(pk) # GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer). - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier) + cute.copy(tma_k, tBgK[(None, kt, None)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier) pk = cutlass.Boolean(1) vh = kvp.acquire_and_advance(pk) # GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer). - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier) + cute.copy(tma_v, tVgV[(None, kt, None)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier) pk = cutlass.Boolean(1) kvp.tail()