Fix TMA indexing: 4-mode tensors, kt at mode 2 (GMEM tile dim)

This commit is contained in:
2026-05-22 21:51:33 +00:00
parent 61b0501a8b
commit 845ad98b22

View File

@@ -178,14 +178,15 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# CRITICAL: tBgK/tVgV have 8 modes after tma_partition.
# Mode 4 is the GMEM tile iteration axis. Pre-slicing with
# (None,None,0,0) collapses modes 4-7 to 0 TMA always reads tile 0.
# Fix: 8-None no-op slice preserves all modes; (None, kt) in copy
# addresses mode 4 correctly.
# CRITICAL: After tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?)
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# The old pre-slice (None,None,0,0) set mode 2 to 0, so TMA always
# read tile 0. Fix: don't pre-slice; use 4-mode indexing in cute.copy
# with kt at mode 2.
# tVgV similarly has mode 2 as the GMEM tile dim.
# tAgQ is fine — Q has only 1 tile, no iteration needed.
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
# No pre-slice for tBgK/tVgV — we index all 4 modes in cute.copy
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -221,8 +222,9 @@ class FmhaV3StageCMulti:
# correctly addresses mode 4 (GMEM tile dim) in cute.copy.
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
# 4-mode indexing: mode 2 = GMEM tile dim (n_kv_tiles)
cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kvp.tail()
# ===== MMA warp =====