Fix group_modes range in TMA shape diag
This commit is contained in:
@@ -59,10 +59,10 @@ class TmaShapeDiag:
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
|
||||
a_lay = cute.make_layout(1)
|
||||
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,2),cute.group_modes(tCgQ,0,3))
|
||||
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,3),cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(1)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,2),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,2),cute.group_modes(tCgV,0,3))
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,3),cute.group_modes(tCgV,0,3))
|
||||
|
||||
print(f"tAgQ shape: {cute.shape(tAgQ)} rank: {tAgQ.rank}")
|
||||
print(f"tBgK shape: {cute.shape(tBgK)} rank: {tBgK.rank}")
|
||||
|
||||
Reference in New Issue
Block a user