auto: pre-test commit
This commit is contained in:
@@ -78,9 +78,12 @@ class DiagShapes:
|
||||
|
||||
a_lay = cute.make_layout(cute.slice_(self.cluster_layout_vmnk,(0,0,None,0)).shape)
|
||||
b_lay = cute.make_layout(cute.slice_(self.cluster_layout_vmnk,(0,None,0,0)).shape)
|
||||
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(cute.slice_(self.q_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgQ,0,3))
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(cute.slice_(self.k_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(cute.slice_(self.v_smem_s,(None,None,None,0)),0,3),cute.group_modes(tCgV,0,3))
|
||||
sQ = cute.slice_(self.q_smem_s,(None,None,None,0))
|
||||
sK = cute.slice_(self.k_smem_s,(None,None,None,0))
|
||||
sV = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
|
||||
print(f"tAgQ shape: {cute.shape(tAgQ)} stride: {tAgQ.layout.stride}")
|
||||
print(f"tBgK shape: {cute.shape(tBgK)} stride: {tBgK.layout.stride}")
|
||||
|
||||
Reference in New Issue
Block a user