auto: pre-test commit

This commit is contained in:
2026-05-22 23:28:16 +00:00
parent 9c5122f180
commit fa6388dbf6

View File

@@ -78,9 +78,12 @@ class DiagShapes:
a_lay = cute.make_layout(cute.slice_(self.cluster_layout_vmnk,(0,0,None,0)).shape)
b_lay = cute.make_layout(cute.slice_(self.cluster_layout_vmnk,(0,None,0,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(cute.slice_(self.q_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgQ,0,3))
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(cute.slice_(self.k_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(cute.slice_(self.v_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgV,0,3))
sQ = cute.slice_(self.q_smem_s,(None,None,None,0))
sK = cute.slice_(self.k_smem_s,(None,None,None,0))
sV = cute.slice_(self.v_smem_s,(None,None,None,0))
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
print(f"tAgQ shape: {cute.shape(tAgQ)} stride: {tAgQ.layout.stride}")
print(f"tBgK shape: {cute.shape(tBgK)} stride: {tBgK.layout.stride}")