FIX: 8-None no-op pre-slice opens full TMA coordinate space (8 dims)
The tma_partition output has 8 TMA coordinate dimensions, not 4. The Python-visible shape shows 4 modes, but the TMA descriptor uses 8 coordinates. Without the 8-None no-op pre-slice, modes 4-7 are collapsed and the GMEM tile axis (mode 4) is pinned to 0.

Pattern that works (confirmed on B200 at n=256 in diag test):

    tBgK = tBgK[(None,None,None,None,None,None,None,None)]  # open 8D
    cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)

The old 4-mode indexing tBgK[(None,None,kt,0)] fails with 'rank mismatch: got 2 and 1' because slicing a 4-mode tensor produces wrong rank for the TMA coordinate space.

Matches working diag test test_fmha_v3_diag.py exactly.
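To make the pinning effect concrete, here is a minimal pure-Python sketch. It is not CuTeDSL and does not reproduce how `tma_partition` maps the four visible modes onto eight coordinates. The names `apply_slice`, `index_tile`, `tile_addressed`, and `tiles_read` are hypothetical, and the toy assumes a flat 8-mode coordinate space in which a mode pinned by an earlier slice cannot be re-opened by a later one. The "broken" pre-slice below pins mode 4 directly, standing in for the effect the commit attributes to the old `(None, None, 0, 0)` pre-slice.

```python
# Toy model only: a "view" is a dict of {mode: pinned coordinate}.
N_MODES = 8      # size of the coordinate space named in the commit message
TILE_MODE = 4    # mode the commit identifies as the GMEM tile axis


def apply_slice(pinned, coord):
    """Apply one slice: None keeps a mode open, an int pins it.

    Assumption of this toy: a mode pinned earlier stays pinned.
    """
    if len(coord) != N_MODES:
        raise ValueError(f"expected {N_MODES} coordinates, got {len(coord)}")
    result = dict(pinned)
    for mode, c in enumerate(coord):
        if c is not None and mode not in result:
            result[mode] = c
    return result


def index_tile(kt):
    """Build the 8-entry index with kt at the tile mode and None elsewhere."""
    return (None,) * TILE_MODE + (kt,) + (None,) * (N_MODES - TILE_MODE - 1)


def tile_addressed(pinned):
    """Tile coordinate a copy would use once every open mode defaults to 0."""
    return pinned.get(TILE_MODE, 0)


def tiles_read(pre_slice, n_kv_tiles):
    """Tiles addressed by a loop over kt after the given pre-slice."""
    view = apply_slice({}, pre_slice)
    return [tile_addressed(apply_slice(view, index_tile(kt)))
            for kt in range(n_kv_tiles)]


working_tiles = tiles_read((None,) * N_MODES, 4)  # no-op pre-slice
broken_tiles = tiles_read(index_tile(0), 4)       # tile mode pinned up front
print(working_tiles)  # [0, 1, 2, 3]
print(broken_tiles)   # [0, 0, 0, 0]
```

In this model the no-op pre-slice leaves the tile mode open, so each `kt` reaches its own tile, while the pinned view returns tile 0 for every `kt`.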
66
README.md
@@ -1,63 +1,68 @@
 # DSV4 Inference Kernel
 
-## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
+## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Coordinate Space ⚠️⚠️⚠️
 
 **THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
 
-After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**:
+After `cpasync.tma_partition()`, the output GMEM tensor has **8 TMA coordinate dimensions**:
 
 ```
-tBgK shape: (((64, 128), 1), ?, ?, ?)
-mode               0         1  2  3
+tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
+                       0  1  2  3      4       5  6  7
 ```
 
-**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
+**Mode 4 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
 
+The Python-visible shape only shows 4 modes, but the TMA coordinate space is 8-dimensional. You MUST apply an 8-None no-op pre-slice to open the full coordinate space before indexing.
+
 ### THE WRONG WAY (what we did — silently loads from tile 0 forever):
 
 ```python
-# ❌❌❌ PRE-SLICE SETS MODE 2 TO 0 ❌❌❌
-# The (None, None, 0, 0) slice passes 0 for mode 2, pinning the
-# GMEM tile dim to tile 0. The TMA ALWAYS reads from tile 0.
-tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
+# ❌❌❌ 4-MODE PRE-SLICE COLLAPSES THE GMEM TILE AXIS ❌❌❌
+# The (None, None, 0, 0) slice only addresses 4 of 8 TMA coord dims.
+# Modes 4-7 get collapsed to coordinate 0. TMA ALWAYS reads tile 0.
+tBgK = tBgK[(None, None, 0, 0)] # ← WRONG!
 
 # The copy "works" but kv_coord indexes mode 1 (size 1), so
 # every coordinate maps to the same TMA descriptor.
 cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored!
 ```
 
-### THE RIGHT WAY (what actually works):
+### THE RIGHT WAY (what actually works — confirmed on B200 at n=256):
 
 ```python
-# ✅ Do NOT pre-slice. Index all 4 modes explicitly in cute.copy.
-# Put kt at mode 2 (the GMEM tile dim) and 0 elsewhere.
+# ✅ 8-None no-op pre-slice opens the full TMA coordinate space
+tBgK = tBgK[(None, None, None, None, None, None, None, None)]
 
-cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
-#                                  ^^ MODE 2 — THE GMEM TILE DIM
+# ✅ Index mode 4 (the GMEM tile dim) in the copy call
+cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
+#                                             ^^ MODE 4 — THE GMEM TILE DIM
 ```
 
 ### WHY THIS IS SO INSIDIOUS
 
-1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
-2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
+1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently collapses modes 4-7.
+2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
 3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
 4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
 5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
 
 ### THE LESSON
 
-**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors if you need to index a mode that the slice would collapse.**
+**TMA tensors use a coordinate space, not a pointer space.** The TMA instruction at PTX level takes integer coordinates (`crd0, crd1, crd2, ...`), not pointers. CuTeDSL's `tma_partition` produces a tensor whose layout maps logical coordinates to TMA coordinate tuples. When you pre-slice with fewer dimensions than the TMA descriptor expects, the extra coordinate dimensions get collapsed to 0.
 
+**The 8-None no-op pre-slice is mandatory for multi-tile TMA.** Without it, the GMEM tile axis (mode 4) is invisible and unindexable.
+
 ```python
-# After tma_partition, ALWAYS verify the shape:
-print(f"tBgK shape: {cute.shape(tBgK)}")
-print(f"Mode 2 size: {cute.size(tBgK, mode=[2])}")
+# After tma_partition, ALWAYS apply the 8-None no-op pre-slice:
+tBgK = tBgK[(None, None, None, None, None, None, None, None)]
+tVgV = tVgV[(None, None, None, None, None, None, None, None)]
 
-# Index all 4 modes explicitly, kt at mode 2
-cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
+# Then index mode 4 in cute.copy:
+cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
 ```
 
-**IF YOU PRE-SLICE AND PIN THE GMEM TILE MODE TO 0, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
+**IF YOU SKIP THE 8-NONE PRE-SLICE, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
 
 ---
 
@@ -289,19 +294,20 @@ What it does:
 
 ### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
 
-After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**:
+After `cpasync.tma_partition()`, the output GMEM tensor has **8 TMA coordinate dimensions**:
 
 ```
-tBgK shape: (((64, 128), 1), ?, ?, ?)
-mode               0         1  2  3
+tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
+                       0  1  2  3      4       5  6  7
 ```
 
-**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
+**Mode 4 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` collapsed modes 4-7 to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
 
-**The fix:** Do not pre-slice. Index all 4 modes explicitly in `cute.copy`, putting `kt` at mode 2:
+**The fix:** 8-None no-op pre-slice + 8-mode indexing with `kt` at mode 4:
 
 ```python
-cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...)
+tBgK = tBgK[(None, None, None, None, None, None, None, None)]
+cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
 ```
 
 **Results after fix:**
@@ -341,7 +347,7 @@ Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Ep
 1. `vectorize=True` loops: ONLY load/store/print
 2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
 3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
-4. `tBgK`/`tVgV` have 4 modes after tma_partition — mode 2 is GMEM tile dim, must index all 4 explicitly
+4. `tBgK`/`tVgV` have 8 TMA coord dims after tma_partition — 8-None no-op pre-slice required, mode 4 is GMEM tile dim
 5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
 6. `softmax_done_bar` NamedBarrier is reusable across tiles
 

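The README changes above note that a single-tile run hides the bug and that multi-tile runs still yield plausible output. One way to catch a pinned tile axis is to tag each KV tile with its own index and check what the loader returns. The sketch below is pure Python and hypothetical: `make_tagged_tiles`, `stuck_tiles`, and the two loader functions are illustrative names, not part of the kernel or of CuTeDSL, and the real check would compare device buffers rather than Python lists.

```python
def make_tagged_tiles(n_tiles, tile_elems):
    """Build n_tiles tiles where tile i holds the value float(i) everywhere."""
    return [[float(i)] * tile_elems for i in range(n_tiles)]


def stuck_tiles(load_tile, n_tiles):
    """Return every kt for which load_tile(kt) did not return tile kt."""
    bad = []
    for kt in range(n_tiles):
        tile = load_tile(kt)
        if any(value != float(kt) for value in tile):
            bad.append(kt)
    return bad


gmem = make_tagged_tiles(n_tiles=4, tile_elems=8)


def correct_loader(kt):
    """Stand-in for a loader that honours the tile coordinate."""
    return gmem[kt]


def pinned_loader(kt):
    """Stand-in for the bug: the tile coordinate is ignored, tile 0 is read."""
    return gmem[0]


print(stuck_tiles(correct_loader, 4))  # []
print(stuck_tiles(pinned_loader, 4))   # [1, 2, 3]
print(stuck_tiles(pinned_loader, 1))   # [] (a single tile cannot expose the bug)
```

Because the tags are exact integers, this check does not depend on a similarity threshold, so a pinned tile axis shows up as a hard failure once there are at least two tiles.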
@@ -14,19 +14,18 @@ Three structural rules learned the hard way:
 `utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
 the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
 
-(C) tma_partition produces a 4-mode tensor for tBgK/tVgV, not a 2-mode one. After
+(C) tma_partition produces a tensor with 8 TMA coordinate dimensions, but only
+4 are visible in the Python shape. After
     tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay,
                                        group_modes(sK,0,3),
                                        group_modes(tCgK,0,3))
-`tBgK` shape is (((64,128),1),?,?,?) with 4 modes. Mode 2 is the
-GMEM-tile iteration axis. Pre-slicing with `tBgK[(None,None,0,0)]`
-sets mode 2 to 0, so every TMA copy reads tile 0 regardless of
-what's passed at mode 1. The bug pretends to be a JIT issue:
-dynamic coords seem to be "constant-folded" because the axis they
-vary along is pinned to 0 by the pre-slice.
-
-Fix: do not pre-slice. Index all 4 modes explicitly in the producer's
-`cute.copy`, putting `kt` at mode 2 and `0` or `None` everywhere else.
+`tBgK` has 8 TMA coord modes: (1,1,1,1,n_kv_tiles,1,1,1).
+Mode 4 is the GMEM-tile iteration axis.
+Pre-slicing with `tBgK[(None,None,0,0)]` collapses the GMEM tile axis to
+coordinate 0, so every TMA copy reads tile 0 regardless of the coord
+value passed. The 8-None no-op pre-slice opens the full TMA coord space.
+Fix: tBgK = tBgK[(None,None,None,None,None,None,None,None)], then
+     cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)
 
 Kernel structure:
 
@@ -186,13 +185,15 @@ class FmhaV3StageCMulti:
         b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
         tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
         tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
-        # NOTE: after tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?).
-        # Mode 2 is the GMEM tile-iteration axis (size = n_kv_tiles).
-        # We previously pre-sliced like tBgK[(None,None,0,0)] which set mode 2
-        # to 0, so TMA always read from tile 0. Fix: no pre-slice — index
-        # all 4 modes explicitly in cute.copy, putting kt at mode 2.
-        # tVgV similarly has 4 modes with mode 2 as the GMEM tile dim.
-        # tAgQ is fine with 4-mode slice (Q has only 1 tile).
+        # NOTE: after tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
+        # Shape is (1,1,1,1,n_kv_tiles,1,1,1) in TMA coord space.
+        # Mode 4 is the GMEM tile-iteration axis.
+        # The no-op 8-None slice opens up the full TMA coordinate space.
+        # Without it, only 4 modes are visible and 8-mode indexing fails.
+        # The old (None,None,0,0) pre-slice collapsed the GMEM tile axis to 0.
         tAgQ = tAgQ[(None,0,None,0)]
+        tBgK = tBgK[(None,None,None,None,None,None,None,None)]
+        tVgV = tVgV[(None,None,None,None,None,None,None,None)]
+
         tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
         tCrV = pv_mma.make_fragment_B(sV)
@@ -216,14 +217,13 @@ class FmhaV3StageCMulti:
         pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
 
         # ===== TMA LOAD warp — fully unrolled =====
-        # The original pre-slice tBgK[(None,None,0,0)] pinned mode 2 (the
-        # GMEM tile dim) to 0, forcing all TMA reads to tile 0. With no
-        # pre-slice, we index all 4 modes explicitly in cute.copy, putting
-        # kt at mode 2. The pipeline's acquire/release machinery still
-        # tracks the kv_stage ring buffer dynamically at runtime, so the
-        # producer correctly blocks on consumer release when n_kv_tiles >
-        # kv_stage. The unroll only flattens the LOOP control flow, not
-        # the synchronization.
+        # The old pre-slice tBgK[(None,None,0,0)] collapsed the 8-mode TMA
+        # coordinate space to 4 modes, pinning mode 4 (GMEM tile dim) to 0.
+        # The 8-None no-op pre-slice opens the full TMA coord space so we
+        # can index mode 4 with kt in cute.copy. The pipeline's
+        # acquire/release machinery still tracks the kv_stage ring buffer
+        # dynamically at runtime, so the producer correctly blocks on
+        # consumer release when n_kv_tiles > kv_stage.
         if warp_idx == self.tma_warp_id:
             qp.reset(); qh = qp.acquire_and_advance()
             cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
@@ -231,20 +231,18 @@ class FmhaV3StageCMulti:
             kvp.reset()
             for kt in cutlass.range_constexpr(self.n_kv_tiles):
                 kvh = kvp.acquire_and_advance()
-                # tBgK has 4 modes after tma_partition: (((64,128),1),?,?,?).
-                # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
-                # The old pre-slice (None,None,0,0) set mode 2 to 0, forcing
-                # all TMA reads to tile 0. With no pre-slice, we index all
-                # 4 modes explicitly, putting kt at mode 2.
+                # 8-mode TMA indexing: mode 4 = GMEM tile axis (size n_kv_tiles).
+                # The 8-None pre-slice above opened the full coord space.
+                # kt at mode 4 indexes the correct KV tile in GMEM.
                 cute.copy(
                     tma_k,
-                    tBgK[(None, None, kt, 0)],
+                    tBgK[None, None, None, None, kt, None, None, None],
                     tBsK[(None, kvh.index)],
                     tma_bar_ptr=kvh.barrier,
                 )
                 cute.copy(
                     tma_v,
-                    tVgV[(None, 0, kt, 0)],
+                    tVgV[None, None, None, None, kt, None, None, None],
                     tVsV[(None, kvh.index)],
                     tma_bar_ptr=kvh.barrier,
                 )

@@ -179,11 +179,13 @@ class FmhaV3StageCMulti:
         b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
         tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
         tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
-        # After tma_partition, tBgK/tVgV have 4 modes: (((64,128),1),?,?,?).
-        # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
-        # Do NOT pre-slice — index all 4 modes explicitly in cute.copy.
+        # After tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
+        # Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles).
+        # 8-None no-op pre-slice opens the full TMA coord space.
         # tAgQ is fine with 4-mode slice (Q has only 1 tile).
         tAgQ = tAgQ[(None,0,None,0)]
+        tBgK = tBgK[(None,None,None,None,None,None,None,None)]
+        tVgV = tVgV[(None,None,None,None,None,None,None,None)]
 
         tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
         tCrV = pv_mma.make_fragment_B(sV)
@@ -221,8 +223,8 @@ class FmhaV3StageCMulti:
             kvp.reset(); pk = kvp.try_acquire()
             for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
                 kvh = kvp.acquire_and_advance(pk)
-                cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
-                cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
+                cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
+                cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
                 pk = cutlass.Boolean(1)
             kvp.tail()
 
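The producer loops in the hunks above use two different indices per iteration: `kt` selects the GMEM tile, while `kvh.index` selects the SMEM stage. The following pure-Python sketch lists how those two could line up. It rests on assumptions that the diff does not state: that the stage index advances round-robin over `kv_stage` slots, and that a stage can only be refilled after the consumer has released it. `producer_schedule` is a hypothetical helper, not part of the pipeline API.

```python
def producer_schedule(n_kv_tiles, kv_stage):
    """List (kt, stage, must_wait) for each producer iteration.

    Assumptions of this sketch: the stage index is kt modulo kv_stage, and
    the producer has to wait for a consumer release whenever it revisits a
    stage that was already filled once.
    """
    schedule = []
    for kt in range(n_kv_tiles):
        stage = kt % kv_stage        # SMEM ring-buffer slot (wraps around)
        must_wait = kt >= kv_stage   # slot was filled before, needs a release
        schedule.append((kt, stage, must_wait))
    return schedule


for kt, stage, must_wait in producer_schedule(n_kv_tiles=4, kv_stage=2):
    print(f"kt={kt} -> GMEM tile {kt}, SMEM stage {stage}, wait={must_wait}")
```

Under these assumptions the GMEM coordinate keeps increasing while the SMEM stage wraps, which is why the tile axis has to stay indexable by `kt` on the GMEM side even though the SMEM side only ever sees `kv_stage` distinct indices.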