From b64227e5b66d0add206602dd4a1456baebeacf94 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:53:22 +0000 Subject: [PATCH] Fix group_modes range in TMA shape diag --- tests/unit/test_tma_shapes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_tma_shapes.py b/tests/unit/test_tma_shapes.py index e1f25c3c..b66870bb 100644 --- a/tests/unit/test_tma_shapes.py +++ b/tests/unit/test_tma_shapes.py @@ -59,10 +59,10 @@ class TmaShapeDiag: tCgV = pv_thr.partition_B(gV) a_lay = cute.make_layout(1) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,2),cute.group_modes(tCgQ,0,3)) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,3),cute.group_modes(tCgQ,0,3)) b_lay = cute.make_layout(1) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,2),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,2),cute.group_modes(tCgV,0,3)) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,3),cute.group_modes(tCgV,0,3)) print(f"tAgQ shape: {cute.shape(tAgQ)} rank: {tAgQ.rank}") print(f"tBgK shape: {cute.shape(tBgK)} rank: {tBgK.rank}")