Fix group_modes range in TMA shape diag

This commit is contained in:
2026-05-22 21:53:22 +00:00
parent be27720cb2
commit b64227e5b6

View File

@@ -59,10 +59,10 @@ class TmaShapeDiag:
tCgV = pv_thr.partition_B(gV)
a_lay = cute.make_layout(1)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,2),cute.group_modes(tCgQ,0,3))
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(q_s,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(1)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,2),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,2),cute.group_modes(tCgV,0,3))
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(k_s,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(v_s,0,3),cute.group_modes(tCgV,0,3))
print(f"tAgQ shape: {cute.shape(tAgQ)} rank: {tAgQ.rank}")
print(f"tBgK shape: {cute.shape(tBgK)} rank: {tBgK.rank}")