FIX: tma_partition tensors have 4 modes, not 8. Mode 2 is GMEM tile dim.

The 8-mode indexing (tBgK[None,None,None,None,kt,None,None,None]) fails at
JIT compilation with 'coord and shape are weakly congruent' error. The actual
MLIR tensor shape is (((64,128),1),?,?,?) — 4 modes, not 8.

The working fix from commit 845ad98 on the B200 used 4-mode indexing all along:
  tBgK[(None, None, kt, 0)] — mode 2 = GMEM tile dim
  tVgV[(None, 0, kt, 0)] — mode 2 = GMEM tile dim

Updated all files: example10, test_fmha_v3_stage_c, README, docstrings.
This commit is contained in:
2026-05-22 23:08:27 +00:00
parent d7cdf63c58
commit 9c5adcee46
3 changed files with 64 additions and 76 deletions

View File

@@ -4,61 +4,60 @@
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, NOT 4:
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**:
```
tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
0 1 2 3 4 5 6 7
tBgK shape: (((64, 128), 1), ?, ?, ?)
mode 0 1 2 3
```
**Mode 4 (0-indexed) is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
```python
# ❌❌❌ THIS ONLY ADDRESSES 4 OF THE 8 MODES ❌❌❌
# Mode 4 (the GMEM tile dim) is NEVER touched by this slice.
# It stays fixed at 0. The TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Only 4 modes addressed!
# ❌❌❌ PRE-SLICE SETS MODE 2 TO 0 ❌❌❌
# The (None, None, 0, 0) slice passes 0 for mode 2, pinning the
# GMEM tile dim to tile 0. The TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
# The copy "works" but kv_coord is meaningless — it indexes mode 1,
# which has size 1, so every coordinate maps to the same TMA descriptor.
# The copy "works" but kv_coord indexes mode 1 (size 1), so
# every coordinate maps to the same TMA descriptor.
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored!
```
### THE RIGHT WAY (what actually works):
```python
# ✅ Keep ALL 8 modes. Do NOT pre-slice the TMA GMEM tensor.
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
# ✅ Do NOT pre-slice. Index all 4 modes explicitly in cute.copy.
# Put kt at mode 2 (the GMEM tile dim) and 0 elsewhere.
# ✅ Index mode 4 (the GMEM tile dim) in the copy call
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
# ^^ THIS IS MODE 4 — THE GMEM TILE DIM
cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
# ^^ MODE 2 — THE GMEM TILE DIM
```
### WHY THIS IS SO INSIDIOUS
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently drops modes 4-7.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
### THE LESSON
**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors with fewer dimensions than they actually have.**
**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors if you need to index a mode that the slice would collapse.**
```python
# After tma_partition, ALWAYS verify the shape:
print(f"tBgK shape: {cute.shape(tBgK)}") # Should show 8 modes
print(f"n_kv_tiles at mode 4: {cute.size(tBgK, mode=[4])}")
print(f"tBgK shape: {cute.shape(tBgK)}")
print(f"Mode 2 size: {cute.size(tBgK, mode=[2])}")
# Use full indexing in cute.copy, not a pre-sliced 2D view
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
# Index all 4 modes explicitly, kt at mode 2
cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
```
**IF YOU SKIP THIS AND PRE-SLICE TO 4 MODES, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
**IF YOU PRE-SLICE AND PIN THE GMEM TILE MODE TO 0, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
---
@@ -290,19 +289,19 @@ What it does:
### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, not 4:
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**:
```
tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
0 1 2 3 4 5 6 7
tBgK shape: (((64, 128), 1), ?, ?, ?)
mode 0 1 2 3
```
**Mode 4 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` only addressed 4 modes — modes 4-7 were implicitly collapsed to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
**The fix:** Do not pre-slice. Index all 8 modes explicitly in `cute.copy`, putting `kt` at mode 4:
**The fix:** Do not pre-slice. Index all 4 modes explicitly in `cute.copy`, putting `kt` at mode 2:
```python
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
```
**Results after fix:**
@@ -342,7 +341,7 @@ Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Ep
1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
4. `tBgK`/`tVgV` have 8 modes after tma_partition — mode 4 is GMEM tile dim, must index all 8 explicitly
4. `tBgK`/`tVgV` have 4 modes after tma_partition — mode 2 is GMEM tile dim, must index all 4 explicitly
5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
6. `softmax_done_bar` NamedBarrier is reusable across tiles

View File

@@ -14,20 +14,19 @@ Three structural rules learned the hard way:
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
(C) tma_partition produces an 8-mode tensor, not a 4-mode one. After
(C) tma_partition produces a 4-mode tensor for tBgK/tVgV, not a 2-mode one. After
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay,
group_modes(sK,0,3),
group_modes(tCgK,0,3))
`tBgK` shape is (1, 1, 1, 1, n_kv_tiles, 1, 1, 1). Mode 4 is the
`tBgK` shape is (((64,128),1),?,?,?) with 4 modes. Mode 2 is the
GMEM-tile iteration axis. Pre-slicing with `tBgK[(None,None,0,0)]`
addresses only 4 modes — the remaining modes (including the KV-tile
axis at mode 4) get implicitly collapsed to coord 0, so every TMA copy
reads tile 0 regardless of what's passed at index 1. The bug pretends
to be a JIT issue: dynamic coords seem to be "constant-folded" because
the only axis they could vary along has stride 0.
sets mode 2 to 0, so every TMA copy reads tile 0 regardless of
what's passed at mode 1. The bug pretends to be a JIT issue:
dynamic coords seem to be "constant-folded" because the axis they
vary along is pinned to 0 by the pre-slice.
Fix: do not pre-slice. Index all 8 modes explicitly in the producer's
`cute.copy`, putting `kt` at mode 4 and `None` (or 0) everywhere else.
Fix: do not pre-slice. Index all 4 modes explicitly in the producer's
`cute.copy`, putting `kt` at mode 2 and `0` or `None` everywhere else.
Kernel structure:
@@ -187,15 +186,13 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# NOTE: after tma_partition, ALL three tensors (tAgQ, tBgK, tVgV)
# have 8 modes, e.g. tBgK shape = (1, 1, 1, 1, n_kv_tiles, 1, 1, 1).
# Mode 4 is the GMEM tile-iteration axis; all other modes are size 1.
# We previously pre-sliced like `tBgK[(None,None,0,0)]` which only
# addressed 4 modes — modes 4..7 got swept up into the trailing 0
# and the KV-tile axis was effectively collapsed to tile 0 always.
# That, not a JIT bug, was why every dynamic coord produced tile-0
# data. With no pre-slice, we index all 8 modes explicitly in the
# producer warp below, putting `kt` at mode 4 and `None` elsewhere.
# NOTE: after tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?).
# Mode 2 is the GMEM tile-iteration axis (size = n_kv_tiles).
# We previously pre-sliced like tBgK[(None,None,0,0)] which set mode 2
# to 0, so TMA always read from tile 0. Fix: no pre-slice — index
# all 4 modes explicitly in cute.copy, putting kt at mode 2.
# tVgV similarly has 4 modes with mode 2 as the GMEM tile dim.
# tAgQ is fine with 4-mode slice (Q has only 1 tile).
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -219,43 +216,35 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp — fully unrolled =====
# Why fully unrolled: original suspicion was a JIT bug propagating
# dynamic coords; actual root cause was a layout bug — the pre-slice
# collapsed all the mode-4 GMEM-tile axis to 0. After tma_partition,
# tBgK and tVgV have 8 modes: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) where
# mode 4 is the KV-tile iteration axis.
#
# Indexing rule: pass `None` for every mode that's size 1 (passthrough),
# and `kt` for the KV-tile axis at mode 4.
#
# We keep the unroll for correctness confidence. The pipeline's
# acquire/release machinery still tracks the kv_stage ring buffer
# dynamically at runtime, so the producer correctly blocks on
# consumer release when n_kv_tiles > kv_stage. The unroll only
# flattens the LOOP control flow, not the synchronization.
# The original pre-slice tBgK[(None,None,0,0)] pinned mode 2 (the
# GMEM tile dim) to 0, forcing all TMA reads to tile 0. With no
# pre-slice, we index all 4 modes explicitly in cute.copy, putting
# kt at mode 2. The pipeline's acquire/release machinery still
# tracks the kv_stage ring buffer dynamically at runtime, so the
# producer correctly blocks on consumer release when n_kv_tiles >
# kv_stage. The unroll only flattens the LOOP control flow, not
# the synchronization.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
# Q's 4-mode indexing was confirmed working at n=128 cos 0.999998
# before the multi-tile investigation; leave it alone. Only K/V's
# tma_partition output is 8 modes (the KV-tile axis introduces
# the extra modes — Q has no equivalent tile axis since there's
# one Q tile per CTA).
cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset()
for kt in cutlass.range_constexpr(self.n_kv_tiles):
kvh = kvp.acquire_and_advance()
# 8-mode indices, mode 4 = KV-tile axis (size n_kv_tiles).
# Every other mode is size 1, so None is fine.
# tBgK has 4 modes after tma_partition: (((64,128),1),?,?,?).
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# The old pre-slice (None,None,0,0) set mode 2 to 0, forcing
# all TMA reads to tile 0. With no pre-slice, we index all
# 4 modes explicitly, putting kt at mode 2.
cute.copy(
tma_k,
tBgK[None, None, None, None, kt, None, None, None],
tBgK[(None, None, kt, 0)],
tBsK[(None, kvh.index)],
tma_bar_ptr=kvh.barrier,
)
cute.copy(
tma_v,
tVgV[None, None, None, None, kt, None, None, None],
tVgV[(None, 0, kt, 0)],
tVsV[(None, kvh.index)],
tma_bar_ptr=kvh.barrier,
)

View File

@@ -179,9 +179,9 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# After tma_partition, tBgK/tVgV have 8 modes.
# Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles).
# Do NOT pre-slice — index all 8 modes explicitly in cute.copy.
# After tma_partition, tBgK/tVgV have 4 modes: (((64,128),1),?,?,?).
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# Do NOT pre-slice — index all 4 modes explicitly in cute.copy.
# tAgQ is fine with 4-mode slice (Q has only 1 tile).
tAgQ = tAgQ[(None,0,None,0)]
@@ -221,8 +221,8 @@ class FmhaV3StageCMulti:
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()