FIX: only slice GMEM tensors (SMEM already 2D from tma_partition)
This commit is contained in:
@@ -137,9 +137,7 @@ class FmhaV3Diag:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)]
|
||||
tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
@@ -167,15 +167,13 @@ class FmhaV3StageC:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# Q: GMEM iter at mode 1 (from QK MMA A-partition, layout Q,D,L)
|
||||
# K: GMEM iter at mode 1 (from QK MMA B-partition, layout K,D,L)
|
||||
# V: GMEM iter at mode 2 (from PV MMA B-partition, layout D,K,L)
|
||||
# Must keep GMEM iter dimension FREE for multi-tile KV loads.
|
||||
# CUTLASS reference slices: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0]
|
||||
# Both GMEM and SMEM sides must be sliced consistently for cute.copy rank matching.
|
||||
tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile only
|
||||
tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter)
|
||||
tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter)
|
||||
# Only slice GMEM tensors (SMEM tensors from tma_partition are already 2D).
|
||||
# K from QK MMA: GMEM iter at mode 1 → slice [(None,None,0,0)] keeps modes 0,1
|
||||
# V from PV MMA: GMEM iter at mode 2 → slice [(None,0,None,0)] keeps modes 0,2
|
||||
# Q: 1 tile only, original slice is fine.
|
||||
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode GMEM iter to 0
|
||||
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free for kvh.count
|
||||
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free for kvh.count
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
Reference in New Issue
Block a user