FIX: Use full 8D indexing for tBgK/tVgV — mode 4 is the GMEM tile dim

This commit is contained in:
2026-05-22 21:21:23 +00:00
parent 2a9f764f8b
commit bb92af5b0c

View File

@@ -179,7 +179,12 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)]
# TMA source tensor slices: keep the GMEM tile dimension (mode 4) free
# tBgK shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles
# tVgV shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)] # No-op, use full indexing in copy
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)