diff --git a/tests/unit/test_tma_shapes.py b/tests/unit/test_tma_shapes.py index e1f25c3c..b66870bb 100644 --- a/tests/unit/test_tma_shapes.py +++ b/tests/unit/test_tma_shapes.py @@ -59,10 +59,10 @@ class TmaShapeDiag: tCgV = pv_thr.partition_B(gV) a_lay = cute.make_layout(1) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,2),cute.group_modes(tCgQ,0,3)) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,3),cute.group_modes(tCgQ,0,3)) b_lay = cute.make_layout(1) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,2),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,2),cute.group_modes(tCgV,0,3)) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,3),cute.group_modes(tCgV,0,3)) print(f"tAgQ shape: {cute.shape(tAgQ)} rank: {tAgQ.rank}") print(f"tBgK shape: {cute.shape(tBgK)} rank: {tBgK.rank}")