FIX: (None,0,None,0) for ALL tma_partition outputs — verified shapes on B200

DIAG OUTPUT (n=256, inside @cute.kernel):
  tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?))  — 4 modes
  tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?))  — 4 modes
  tVgV: (((64,128),1), 1, 1, 1)                       — 4 modes

After (None,0,None,0) → keeps modes 0 and 2 free → 2D:
  tAgQ: (((64,128),1), Int32(?))
  tBgK: (((64,128),1), Int32(?))
  tVgV: (((64,128),1), 1)

Then [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles).
tAgQ[(None, Int32(0))] for Q (1 tile, coordinate is always 0).
Removed diag prints from test_fmha_v3.py.
This commit is contained in:
2026-05-22 23:35:55 +00:00
parent a50cb138c8
commit 6db7fd339d
2 changed files with 17 additions and 25 deletions

View File

@@ -14,14 +14,14 @@ Three structural rules learned the hard way:
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
(C) tma_partition produces a 4-mode tensor: (V_grouped, ?, KV_tiles, ?).
Mode 2 is the GMEM-tile iteration axis. Pre-slicing with
`tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets mode 2 to 0,
collapsing the KV-tile axis so TMA always reads tile 0.
(C) tma_partition produces 4-mode tensors: (((64,128),1), ?, ?, ?).
Mode 2 is the GMEM-tile iteration axis.
Pre-slicing with `tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets
mode 2 to 0, collapsing the KV-tile axis so TMA always reads tile 0.
Fix: pre-slice with `tBgK[(None,0,None,0)]` to keep modes 0,2 free,
then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes mode 1 (the
surviving KV_tiles mode from the original mode 2) with kt.
This matches the CUTLASS reference FMHA pattern.
then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes the surviving
KV_tiles mode. Verified shapes on B200: after (None,0,None,0), all
three tensors become 2D: (((64,128),1), Int32(?)) or (((64,128),1), 1).
Kernel structure:
@@ -181,14 +181,15 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# NOTE: after tma_partition, tAgQ has 2 modes (Q has 1 tile),
# tBgK/tVgV have 4 modes: (V_grouped, ?, KV_tiles, ?).
# Mode 2 of tBgK/tVgV is the GMEM tile iteration axis.
# The old pre-slice (None,None,0,0) kept modes 0,1 free and set
# mode 2 to 0, collapsing the KV tile axis.
# Fix: (None,0,None,0) keeps modes 0,2 free, then [None, kt]
# indexes the surviving KV_tiles dim in cute.copy.
tAgQ = tAgQ[(None,0)]
# SHAPES (from diag test on B200, n=256):
# tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
# tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
# tVgV: (((64,128),1), 1, 1, 1) — 4 modes
# Mode 2 is the GMEM tile iteration axis.
# (None,0,None,0) keeps modes 0 and 2 free → 2D tensor.
# Then [None, kt] indexes the surviving KV_tiles dim.
# The old (None,None,0,0) kept modes 0,1 free → mode 2 (KV tiles) set to 0.
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,0,None,0)]
tVgV = tVgV[(None,0,None,0)]
@@ -220,7 +221,7 @@ class FmhaV3StageCMulti:
# Matches the CUTLASS reference FMHA TMA indexing pattern.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[0], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[None, Int32(0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset()
for kt in cutlass.range_constexpr(self.n_kv_tiles):

View File

@@ -134,16 +134,7 @@ class FmhaV3:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
print(f"DIAG tAgQ: shape={cute.shape(tAgQ)} stride={tAgQ.layout.stride}")
print(f"DIAG tBgK: shape={cute.shape(tBgK)} stride={tBgK.layout.stride}")
print(f"DIAG tVgV: shape={cute.shape(tVgV)} stride={tVgV.layout.stride}")
print(f"DIAG tAsQ: shape={cute.shape(tAsQ)} stride={tAsQ.layout.stride}")
print(f"DIAG tBsK: shape={cute.shape(tBsK)} stride={tBsK.layout.stride}")
print(f"DIAG tVsV: shape={cute.shape(tVsV)} stride={tVsV.layout.stride}")
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
print(f"DIAG tAgQ sliced (None,0,None,0): shape={cute.shape(tAgQ)} stride={tAgQ.layout.stride}")
print(f"DIAG tBgK sliced (None,0,None,0): shape={cute.shape(tBgK)} stride={tBgK.layout.stride}")
print(f"DIAG tVgV sliced (None,0,None,0): shape={cute.shape(tVgV)} stride={tVgV.layout.stride}")
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)