FIX: 8-None no-op pre-slice opens full TMA coordinate space (8 dims)

The tma_partition output has 8 TMA coordinate dimensions, not 4.
The Python-visible shape shows 4 modes, but the TMA descriptor uses
8 coordinates. Without the 8-None no-op pre-slice, modes 4-7 are
collapsed and the GMEM tile axis (mode 4) is pinned to 0.

Pattern that works (confirmed on B200 at n=256 in diag test):
  tBgK = tBgK[(None,None,None,None,None,None,None,None)]  # open 8D
  cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)

The old 4-mode indexing tBgK[(None,None,kt,0)] fails with
'rank mismatch: got 2 and 1' because slicing a 4-mode tensor
produces wrong rank for the TMA coordinate space.

Matches working diag test test_fmha_v3_diag.py exactly.
This commit is contained in:
2026-05-22 23:18:40 +00:00
parent 9c5adcee46
commit 30eaba39aa
3 changed files with 73 additions and 67 deletions

View File

@@ -1,63 +1,68 @@
# DSV4 Inference Kernel
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Coordinate Space ⚠️⚠️⚠️
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**:
After `cpasync.tma_partition()`, the output GMEM tensor has **8 TMA coordinate dimensions**:
```
tBgK shape: (((64, 128), 1), ?, ?, ?)
mode 0 1 2 3
tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
0 1 2 3 4 5 6 7
```
**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
**Mode 4 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
The Python-visible shape only shows 4 modes, but the TMA coordinate space is 8-dimensional. You MUST apply an 8-None no-op pre-slice to open the full coordinate space before indexing.
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
```python
# ❌❌❌ PRE-SLICE SETS MODE 2 TO 0 ❌❌❌
# The (None, None, 0, 0) slice passes 0 for mode 2, pinning the
# GMEM tile dim to tile 0. The TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
# ❌❌❌ 4-MODE PRE-SLICE COLLAPSES THE GMEM TILE AXIS ❌❌❌
# The (None, None, 0, 0) slice only addresses 4 of 8 TMA coord dims.
# Modes 4-7 get collapsed to coordinate 0. TMA ALWAYS reads tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG!
# The copy "works" but kv_coord indexes mode 1 (size 1), so
# every coordinate maps to the same TMA descriptor.
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored!
```
### THE RIGHT WAY (what actually works):
### THE RIGHT WAY (what actually works — confirmed on B200 at n=256):
```python
# ✅ Do NOT pre-slice. Index all 4 modes explicitly in cute.copy.
# Put kt at mode 2 (the GMEM tile dim) and 0 elsewhere.
# ✅ 8-None no-op pre-slice opens the full TMA coordinate space
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
# ^^ MODE 2 — THE GMEM TILE DIM
# ✅ Index mode 4 (the GMEM tile dim) in the copy call
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
# ^^ MODE 4 — THE GMEM TILE DIM
```
### WHY THIS IS SO INSIDIOUS
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently collapses modes 4-7.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
### THE LESSON
**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors if you need to index a mode that the slice would collapse.**
**TMA tensors use a coordinate space, not a pointer space.** The TMA instruction at PTX level takes integer coordinates (`crd0, crd1, crd2, ...`), not pointers. CuTeDSL's `tma_partition` produces a tensor whose layout maps logical coordinates to TMA coordinate tuples. When you pre-slice with fewer dimensions than the TMA descriptor expects, the extra coordinate dimensions get collapsed to 0.
**The 8-None no-op pre-slice is mandatory for multi-tile TMA.** Without it, the GMEM tile axis (mode 4) is invisible and unindexable.
```python
# After tma_partition, ALWAYS verify the shape:
print(f"tBgK shape: {cute.shape(tBgK)}")
print(f"Mode 2 size: {cute.size(tBgK, mode=[2])}")
# After tma_partition, ALWAYS apply the 8-None no-op pre-slice:
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
tVgV = tVgV[(None, None, None, None, None, None, None, None)]
# Index all 4 modes explicitly, kt at mode 2
cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
# Then index mode 4 in cute.copy:
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
```
**IF YOU PRE-SLICE AND PIN THE GMEM TILE MODE TO 0, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
**IF YOU SKIP THE 8-NONE PRE-SLICE, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
---
@@ -289,19 +294,20 @@ What it does:
### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**:
After `cpasync.tma_partition()`, the output GMEM tensor has **8 TMA coordinate dimensions**:
```
tBgK shape: (((64, 128), 1), ?, ?, ?)
mode 0 1 2 3
tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
0 1 2 3 4 5 6 7
```
**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
**Mode 4 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` collapsed modes 4-7 to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
**The fix:** Do not pre-slice. Index all 4 modes explicitly in `cute.copy`, putting `kt` at mode 2:
**The fix:** 8-None no-op pre-slice + 8-mode indexing with `kt` at mode 4:
```python
cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
```
**Results after fix:**
@@ -341,7 +347,7 @@ Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Ep
1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
4. `tBgK`/`tVgV` have 4 modes after tma_partition — mode 2 is GMEM tile dim, must index all 4 explicitly
4. `tBgK`/`tVgV` have 8 TMA coord dims after tma_partition — 8-None no-op pre-slice required, mode 4 is GMEM tile dim
5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
6. `softmax_done_bar` NamedBarrier is reusable across tiles

View File

@@ -14,19 +14,18 @@ Three structural rules learned the hard way:
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
(C) tma_partition produces a 4-mode tensor for tBgK/tVgV, not a 2-mode one. After
(C) tma_partition produces a tensor with 8 TMA coordinate dimensions, but only
4 are visible in the Python shape. After
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay,
group_modes(sK,0,3),
group_modes(tCgK,0,3))
`tBgK` shape is (((64,128),1),?,?,?) with 4 modes. Mode 2 is the
GMEM-tile iteration axis. Pre-slicing with `tBgK[(None,None,0,0)]`
sets mode 2 to 0, so every TMA copy reads tile 0 regardless of
what's passed at mode 1. The bug pretends to be a JIT issue:
dynamic coords seem to be "constant-folded" because the axis they
vary along is pinned to 0 by the pre-slice.
Fix: do not pre-slice. Index all 4 modes explicitly in the producer's
`cute.copy`, putting `kt` at mode 2 and `0` or `None` everywhere else.
`tBgK` has 8 TMA coord modes: (1,1,1,1,n_kv_tiles,1,1,1).
Mode 4 is the GMEM-tile iteration axis.
Pre-slicing with `tBgK[(None,None,0,0)]` collapses the GMEM tile axis to
coordinate 0, so every TMA copy reads tile 0 regardless of the coord
value passed. The 8-None no-op pre-slice opens the full TMA coord space.
Fix: tBgK = tBgK[(None,None,None,None,None,None,None,None)], then
cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)
Kernel structure:
@@ -186,13 +185,15 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# NOTE: after tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?).
# Mode 2 is the GMEM tile-iteration axis (size = n_kv_tiles).
# We previously pre-sliced like tBgK[(None,None,0,0)] which set mode 2
# to 0, so TMA always read from tile 0. Fix: no pre-slice — index
# all 4 modes explicitly in cute.copy, putting kt at mode 2.
# tVgV similarly has 4 modes with mode 2 as the GMEM tile dim.
# tAgQ is fine with 4-mode slice (Q has only 1 tile).
# NOTE: after tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
# Shape is (1,1,1,1,n_kv_tiles,1,1,1) in TMA coord space.
# Mode 4 is the GMEM tile-iteration axis.
# The no-op 8-None slice opens up the full TMA coordinate space.
# Without it, only 4 modes are visible and 8-mode indexing fails.
# The old (None,None,0,0) pre-slice collapsed the GMEM tile axis to 0.
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -216,14 +217,13 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp — fully unrolled =====
# The original pre-slice tBgK[(None,None,0,0)] pinned mode 2 (the
# GMEM tile dim) to 0, forcing all TMA reads to tile 0. With no
# pre-slice, we index all 4 modes explicitly in cute.copy, putting
# kt at mode 2. The pipeline's acquire/release machinery still
# tracks the kv_stage ring buffer dynamically at runtime, so the
# producer correctly blocks on consumer release when n_kv_tiles >
# kv_stage. The unroll only flattens the LOOP control flow, not
# the synchronization.
# The old pre-slice tBgK[(None,None,0,0)] collapsed the 8-mode TMA
# coordinate space to 4 modes, pinning mode 4 (GMEM tile dim) to 0.
# The 8-None no-op pre-slice opens the full TMA coord space so we
# can index mode 4 with kt in cute.copy. The pipeline's
# acquire/release machinery still tracks the kv_stage ring buffer
# dynamically at runtime, so the producer correctly blocks on
# consumer release when n_kv_tiles > kv_stage.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
@@ -231,20 +231,18 @@ class FmhaV3StageCMulti:
kvp.reset()
for kt in cutlass.range_constexpr(self.n_kv_tiles):
kvh = kvp.acquire_and_advance()
# tBgK has 4 modes after tma_partition: (((64,128),1),?,?,?).
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# The old pre-slice (None,None,0,0) set mode 2 to 0, forcing
# all TMA reads to tile 0. With no pre-slice, we index all
# 4 modes explicitly, putting kt at mode 2.
# 8-mode TMA indexing: mode 4 = GMEM tile axis (size n_kv_tiles).
# The 8-None pre-slice above opened the full coord space.
# kt at mode 4 indexes the correct KV tile in GMEM.
cute.copy(
tma_k,
tBgK[(None, None, kt, 0)],
tBgK[None, None, None, None, kt, None, None, None],
tBsK[(None, kvh.index)],
tma_bar_ptr=kvh.barrier,
)
cute.copy(
tma_v,
tVgV[(None, 0, kt, 0)],
tVgV[None, None, None, None, kt, None, None, None],
tVsV[(None, kvh.index)],
tma_bar_ptr=kvh.barrier,
)

View File

@@ -179,11 +179,13 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# After tma_partition, tBgK/tVgV have 4 modes: (((64,128),1),?,?,?).
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
# Do NOT pre-slice — index all 4 modes explicitly in cute.copy.
# After tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
# Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles).
# 8-None no-op pre-slice opens the full TMA coord space.
# tAgQ is fine with 4-mode slice (Q has only 1 tile).
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -221,8 +223,8 @@ class FmhaV3StageCMulti:
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()