Fix TMA indexing: 4-mode tensors, kt at mode 2 (GMEM tile dim)
This commit is contained in:
@@ -178,14 +178,15 @@ class FmhaV3StageCMulti:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# CRITICAL: tBgK/tVgV have 8 modes after tma_partition.
|
||||
# Mode 4 is the GMEM tile iteration axis. Pre-slicing with
|
||||
# (None,None,0,0) collapses modes 4-7 to 0 — TMA always reads tile 0.
|
||||
# Fix: 8-None no-op slice preserves all modes; (None, kt) in copy
|
||||
# addresses mode 4 correctly.
|
||||
# CRITICAL: After tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?)
|
||||
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
|
||||
# The old pre-slice (None,None,0,0) set mode 2 to 0, so TMA always
|
||||
# read tile 0. Fix: don't pre-slice; use 4-mode indexing in cute.copy
|
||||
# with kt at mode 2.
|
||||
# tVgV similarly has mode 2 as the GMEM tile dim.
|
||||
# tAgQ is fine — Q has only 1 tile, no iteration needed.
|
||||
tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
|
||||
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
|
||||
# No pre-slice for tBgK/tVgV — we index all 4 modes in cute.copy
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -221,8 +222,9 @@ class FmhaV3StageCMulti:
|
||||
# correctly addresses mode 2 (GMEM tile dim) in cute.copy.
|
||||
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
|
||||
kvh = kvp.acquire_and_advance()
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
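# 4-mode indexing: mode 2 = GMEM tile dim (n_kv_tiles).
# NOTE(review): the K copy leaves mode 1 unsliced (None) while the V copy fixes it to 0 — confirm this asymmetry is intentional.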
|
||||
cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
|
||||
Reference in New Issue
Block a user