FIX: (None,0,None,0) pre-slice keeps KV tile axis (mode 2) free

tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?). Mode 2 is the GMEM tile dim.
Old (None,None,0,0) kept modes 0,1 free → mode 2 collapsed to 0 → always tile 0.
8-None no-op slice FAILS — tensor is 4-mode, not 8-mode, at JIT level.

Fix: (None,0,None,0) keeps modes 0,2 free → 2D tensor.
Then tBgK[None, kt] indexes the surviving KV_tiles dim.

Matches CUTLASS reference FMHA pattern:
  tKgK = tKgK_kdl[None, None, 0, batch]
  cute.copy(tma_k, tKgK[None, kv_coord], ...)
This commit is contained in:
2026-05-22 23:25:40 +00:00
parent 39744ec467
commit 00332825cd
3 changed files with 37 additions and 43 deletions

View File

@@ -14,18 +14,14 @@ Three structural rules learned the hard way:
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
(C) tma_partition produces a tensor with 8 TMA coordinate dimensions, but only
4 are visible in the Python shape. After
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay,
group_modes(sK,0,3),
group_modes(tCgK,0,3))
`tBgK` has 8 TMA coord modes: (1,1,1,1,n_kv_tiles,1,1,1).
Mode 4 is the GMEM-tile iteration axis.
Pre-slicing with `tBgK[(None,None,0,0)]` collapses the GMEM tile axis to
coordinate 0, so every TMA copy reads tile 0 regardless of the coord
value passed. The 8-None no-op pre-slice opens the full TMA coord space.
Fix: tBgK = tBgK[(None,None,None,None,None,None,None,None)], then
cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)
(C) tma_partition produces a 4-mode tensor: (V_grouped, ?, KV_tiles, ?).
Mode 2 is the GMEM-tile iteration axis. Pre-slicing with
`tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets mode 2 to 0,
collapsing the KV-tile axis so TMA always reads tile 0.
Fix: pre-slice with `tBgK[(None,0,None,0)]` to keep modes 0,2 free,
then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes mode 1 (the
surviving KV_tiles mode from the original mode 2) with kt.
This matches the CUTLASS reference FMHA pattern.
Kernel structure:
@@ -185,15 +181,17 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# NOTE: after tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
# Shape is (1,1,1,1,n_kv_tiles,1,1,1) in TMA coord space.
# Mode 4 is the GMEM tile-iteration axis.
# The no-op 8-None slice opens up the full TMA coordinate space.
# Without it, only 4 modes are visible and 8-mode indexing fails.
# The old (None,None,0,0) pre-slice collapsed the GMEM tile axis to 0.
# NOTE: after tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?).
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# The old pre-slice (None,None,0,0) kept modes 0,1 free and set mode 2 to 0,
# collapsing the KV tile axis. Fix: keep modes 0,2 free with (None,0,None,0),
# then index with [None, kt] in cute.copy. This matches the CUTLASS reference
# pattern: 2D pre-slice, 2-mode indexing in copy.
# tVgV similarly has KV tiles at mode 2.
# tAgQ: (None,0,None,0) works because Q has only 1 tile.
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
tBgK = tBgK[(None,0,None,0)]
tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -217,13 +215,10 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp — fully unrolled =====
# The old pre-slice tBgK[(None,None,0,0)] collapsed the 8-mode TMA
# coordinate space to 4 modes, pinning mode 4 (GMEM tile dim) to 0.
# The 8-None no-op pre-slice opens the full TMA coord space so we
# can index mode 4 with kt in cute.copy. The pipeline's
# acquire/release machinery still tracks the kv_stage ring buffer
# dynamically at runtime, so the producer correctly blocks on
# consumer release when n_kv_tiles > kv_stage.
# The old pre-slice (None,None,0,0) kept modes 0,1 free and set
# mode 2 (KV tiles) to 0. Fix: (None,0,None,0) keeps modes 0,2 free.
# Then [None, kt] indexes the KV tile axis correctly.
# Matches the CUTLASS reference FMHA TMA indexing pattern.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
@@ -231,18 +226,17 @@ class FmhaV3StageCMulti:
kvp.reset()
for kt in cutlass.range_constexpr(self.n_kv_tiles):
kvh = kvp.acquire_and_advance()
# 8-mode TMA indexing: mode 4 = GMEM tile axis (size n_kv_tiles).
# The 8-None pre-slice above opened the full coord space.
# kt at mode 4 indexes the correct KV tile in GMEM.
# After (None,0,None,0) pre-slice, tBgK is 2D: (V_grouped, KV_tiles).
# [None, kt] indexes mode 1 (KV_tiles) with kt.
cute.copy(
tma_k,
tBgK[None, None, None, None, kt, None, None, None],
tBgK[None, kt],
tBsK[(None, kvh.index)],
tma_bar_ptr=kvh.barrier,
)
cute.copy(
tma_v,
tVgV[None, None, None, None, kt, None, None, None],
tVgV[None, kt],
tVsV[(None, kvh.index)],
tma_bar_ptr=kvh.barrier,
)

View File

@@ -140,13 +140,13 @@ class FmhaV3Diag:
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# =====================================================================
# ⚠️⚠️⚠️ CRITICAL: TMA PARTITION TENSOR MODE ORDERING ⚠️⚠️⚠️
# After tma_partition, tBgK/tVgV have 8 modes: (1,1,1,1,n_kv_tiles,1,1,1)
# After tma_partition, tBgK/tVgV have 4 modes: (V_grouped, ?, KV_tiles, ?)
# Mode 2 is the GMEM tile dimension. Pre-slice with (None,0,None,0) so it stays free; DO NOT use (None,None,0,0)!
# See README.md for details.
# =====================================================================
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
tVgV = tVgV[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
tBgK = tBgK[(None,0,None,0)]
tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -179,10 +179,10 @@ class FmhaV3Diag:
kv_coord = Int32(0 + 0)
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# ⚠️ CRITICAL: kv_coord indexes MODE 4 of 8-mode tBgK/tVgV.
# kv_coord indexes the surviving mode 1 (from original mode 2) of 2D tBgK/tVgV.
# Pre-slicing with (None,None,0,0) instead pins the KV-tile mode to 0 and SILENTLY BREAKS multi-tile!
cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[None, kv_coord], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, kv_coord], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
pk = cutlass.Boolean(1)
kvp.tail()

View File

@@ -179,13 +179,13 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# After tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
# After tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?).
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# Pre-slice (None,0,None,0) keeps modes 0,2 free; index the KV tile with [None, kt] in cute.copy.
# tAgQ is fine with 4-mode slice (Q has only 1 tile).
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
tBgK = tBgK[(None,0,None,0)]
tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -223,8 +223,8 @@ class FmhaV3StageCMulti:
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[None, kt], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, kt], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()