From fa6388dbf61336974d1175a2fd721f2c267ea511 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 23:28:16 +0000 Subject: [PATCH] auto: pre-test commit --- tests/diag_tma_shapes.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/diag_tma_shapes.py b/tests/diag_tma_shapes.py index 96ee8cf5..0737338e 100644 --- a/tests/diag_tma_shapes.py +++ b/tests/diag_tma_shapes.py @@ -78,9 +78,12 @@ class DiagShapes: a_lay = cute.make_layout(cute.slice_(self.cluster_layout_vmnk,(0,0,None,0)).shape) b_lay = cute.make_layout(cute.slice_(self.cluster_layout_vmnk,(0,None,0,0)).shape) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(cute.slice_(self.q_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgQ,0,3)) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(cute.slice_(self.k_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(cute.slice_(self.v_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgV,0,3)) + sQ = cute.slice_(self.q_smem_s,(None,None,None,0)) + sK = cute.slice_(self.k_smem_s,(None,None,None,0)) + sV = cute.slice_(self.v_smem_s,(None,None,None,0)) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) print(f"tAgQ shape: {cute.shape(tAgQ)} stride: {tAgQ.layout.stride}") print(f"tBgK shape: {cute.shape(tBgK)} stride: {tBgK.layout.stride}")