From 30eaba39aa763452d206dbe4c0b59457041e3dcf Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 23:18:40 +0000 Subject: [PATCH] FIX: 8-None no-op pre-slice opens full TMA coordinate space (8 dims) The tma_partition output has 8 TMA coordinate dimensions, not 4. The Python-visible shape shows 4 modes, but the TMA descriptor uses 8 coordinates. Without the 8-None no-op pre-slice, modes 4-7 are collapsed and the GMEM tile axis (mode 4) is pinned to 0. Pattern that works (confirmed on B200 at n=256 in diag test): tBgK = tBgK[(None,None,None,None,None,None,None,None)] # open 8D cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...) The old 4-mode indexing tBgK[(None,None,kt,0)] fails with 'rank mismatch: got 2 and 1' because slicing a 4-mode tensor produces wrong rank for the TMA coordinate space. Matches working diag test test_fmha_v3_diag.py exactly. --- README.md | 66 ++++++++++++++++-------------- tests/fmha_v3_stage_c_example10.py | 62 ++++++++++++++-------------- tests/unit/test_fmha_v3_stage_c.py | 12 +++--- 3 files changed, 73 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 4c168c11..6485d2d1 100644 --- a/README.md +++ b/README.md @@ -1,63 +1,68 @@ # DSV4 Inference Kernel -## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️ +## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Coordinate Space ⚠️⚠️⚠️ **THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.** -After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: +After `cpasync.tma_partition()`, the output GMEM tensor has **8 TMA coordinate dimensions**: ``` -tBgK shape: (((64, 128), 1), ?, ?, ?) - mode 0 1 2 3 +tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) + 0 1 2 3 4 5 6 7 ``` -**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles. +**Mode 4 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles. + +The Python-visible shape only shows 4 modes, but the TMA coordinate space is 8-dimensional. You MUST apply an 8-None no-op pre-slice to open the full coordinate space before indexing. ### THE WRONG WAY (what we did — silently loads from tile 0 forever): ```python -# ❌❌❌ PRE-SLICE SETS MODE 2 TO 0 ❌❌❌ -# The (None, None, 0, 0) slice passes 0 for mode 2, pinning the -# GMEM tile dim to tile 0. The TMA ALWAYS reads from tile 0. -tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0! +# ❌❌❌ 4-MODE PRE-SLICE COLLAPSES THE GMEM TILE AXIS ❌❌❌ +# The (None, None, 0, 0) slice only addresses 4 of 8 TMA coord dims. +# Modes 4-7 get collapsed to coordinate 0. TMA ALWAYS reads tile 0. +tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! # The copy "works" but kv_coord indexes mode 1 (size 1), so # every coordinate maps to the same TMA descriptor. cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored! ``` -### THE RIGHT WAY (what actually works): +### THE RIGHT WAY (what actually works — confirmed on B200 at n=256): ```python -# ✅ Do NOT pre-slice. Index all 4 modes explicitly in cute.copy. -# Put kt at mode 2 (the GMEM tile dim) and 0 elsewhere. +# ✅ 8-None no-op pre-slice opens the full TMA coordinate space +tBgK = tBgK[(None, None, None, None, None, None, None, None)] -cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...) -# ^^ MODE 2 — THE GMEM TILE DIM +# ✅ Index mode 4 (the GMEM tile dim) in the copy call +cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) +# ^^ MODE 4 — THE GMEM TILE DIM ``` ### WHY THIS IS SO INSIDIOUS -1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0. -2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible. +1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently collapses modes 4-7. +2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible. 3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN. 4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone. 5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode. ### THE LESSON -**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors if you need to index a mode that the slice would collapse.** +**TMA tensors use a coordinate space, not a pointer space.** The TMA instruction at PTX level takes integer coordinates (`crd0, crd1, crd2, ...`), not pointers. CuTeDSL's `tma_partition` produces a tensor whose layout maps logical coordinates to TMA coordinate tuples. When you pre-slice with fewer dimensions than the TMA descriptor expects, the extra coordinate dimensions get collapsed to 0. + +**The 8-None no-op pre-slice is mandatory for multi-tile TMA.** Without it, the GMEM tile axis (mode 4) is invisible and unindexable. ```python -# After tma_partition, ALWAYS verify the shape: -print(f"tBgK shape: {cute.shape(tBgK)}") -print(f"Mode 2 size: {cute.size(tBgK, mode=[2])}") +# After tma_partition, ALWAYS apply the 8-None no-op pre-slice: +tBgK = tBgK[(None, None, None, None, None, None, None, None)] +tVgV = tVgV[(None, None, None, None, None, None, None, None)] -# Index all 4 modes explicitly, kt at mode 2 -cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...) +# Then index mode 4 in cute.copy: +cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) ``` -**IF YOU PRE-SLICE AND PIN THE GMEM TILE MODE TO 0, MULTI-TILE TMA WILL BE SILENTLY BROKEN.** +**IF YOU SKIP THE 8-NONE PRE-SLICE, MULTI-TILE TMA WILL BE SILENTLY BROKEN.** --- @@ -289,19 +294,20 @@ What it does: ### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug) -After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: +After `cpasync.tma_partition()`, the output GMEM tensor has **8 TMA coordinate dimensions**: ``` -tBgK shape: (((64, 128), 1), ?, ?, ?) - mode 0 1 2 3 +tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) + 0 1 2 3 4 5 6 7 ``` -**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error. +**Mode 4 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` collapsed modes 4-7 to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error. -**The fix:** Do not pre-slice. Index all 4 modes explicitly in `cute.copy`, putting `kt` at mode 2: +**The fix:** 8-None no-op pre-slice + 8-mode indexing with `kt` at mode 4: ```python -cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...) +tBgK = tBgK[(None, None, None, None, None, None, None, None)] +cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) ``` **Results after fix:** @@ -341,7 +347,7 @@ Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Ep 1. `vectorize=True` loops: ONLY load/store/print 2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row 3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop -4. `tBgK`/`tVgV` have 4 modes after tma_partition — mode 2 is GMEM tile dim, must index all 4 explicitly +4. `tBgK`/`tVgV` have 8 TMA coord dims after tma_partition — 8-None no-op pre-slice required, mode 4 is GMEM tile dim 5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0 6. `softmax_done_bar` NamedBarrier is reusable across tiles diff --git a/tests/fmha_v3_stage_c_example10.py b/tests/fmha_v3_stage_c_example10.py index 734980d9..f0614abb 100644 --- a/tests/fmha_v3_stage_c_example10.py +++ b/tests/fmha_v3_stage_c_example10.py @@ -14,19 +14,18 @@ Three structural rules learned the hard way: `utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what the CUTLASS Blackwell FMHA reference uses in `correction_rescale`. -(C) tma_partition produces a 4-mode tensor for tBgK/tVgV, not a 2-mode one. After +(C) tma_partition produces a tensor with 8 TMA coordinate dimensions, but only + 4 are visible in the Python shape. After tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, group_modes(sK,0,3), group_modes(tCgK,0,3)) - `tBgK` shape is (((64,128),1),?,?,?) with 4 modes. Mode 2 is the - GMEM-tile iteration axis. Pre-slicing with `tBgK[(None,None,0,0)]` - sets mode 2 to 0, so every TMA copy reads tile 0 regardless of - what's passed at mode 1. The bug pretends to be a JIT issue: - dynamic coords seem to be "constant-folded" because the axis they - vary along is pinned to 0 by the pre-slice. - - Fix: do not pre-slice. Index all 4 modes explicitly in the producer's - `cute.copy`, putting `kt` at mode 2 and `0` or `None` everywhere else. + `tBgK` has 8 TMA coord modes: (1,1,1,1,n_kv_tiles,1,1,1). + Mode 4 is the GMEM-tile iteration axis. + Pre-slicing with `tBgK[(None,None,0,0)]` collapses the GMEM tile axis to + coordinate 0, so every TMA copy reads tile 0 regardless of the coord + value passed. The 8-None no-op pre-slice opens the full TMA coord space. + Fix: tBgK = tBgK[(None,None,None,None,None,None,None,None)], then + cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...) Kernel structure: @@ -186,13 +185,15 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # NOTE: after tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?). - # Mode 2 is the GMEM tile-iteration axis (size = n_kv_tiles). - # We previously pre-sliced like tBgK[(None,None,0,0)] which set mode 2 - # to 0, so TMA always read from tile 0. Fix: no pre-slice — index - # all 4 modes explicitly in cute.copy, putting kt at mode 2. - # tVgV similarly has 4 modes with mode 2 as the GMEM tile dim. - # tAgQ is fine with 4-mode slice (Q has only 1 tile). + # NOTE: after tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions. + # Shape is (1,1,1,1,n_kv_tiles,1,1,1) in TMA coord space. + # Mode 4 is the GMEM tile-iteration axis. + # The no-op 8-None slice opens up the full TMA coordinate space. + # Without it, only 4 modes are visible and 8-mode indexing fails. + # The old (None,None,0,0) pre-slice collapsed the GMEM tile axis to 0. + tAgQ = tAgQ[(None,0,None,0)] + tBgK = tBgK[(None,None,None,None,None,None,None,None)] + tVgV = tVgV[(None,None,None,None,None,None,None,None)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -216,14 +217,13 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp — fully unrolled ===== - # The original pre-slice tBgK[(None,None,0,0)] pinned mode 2 (the - # GMEM tile dim) to 0, forcing all TMA reads to tile 0. With no - # pre-slice, we index all 4 modes explicitly in cute.copy, putting - # kt at mode 2. The pipeline's acquire/release machinery still - # tracks the kv_stage ring buffer dynamically at runtime, so the - # producer correctly blocks on consumer release when n_kv_tiles > - # kv_stage. The unroll only flattens the LOOP control flow, not - # the synchronization. + # The old pre-slice tBgK[(None,None,0,0)] collapsed the 8-mode TMA + # coordinate space to 4 modes, pinning mode 4 (GMEM tile dim) to 0. + # The 8-None no-op pre-slice opens the full TMA coord space so we + # can index mode 4 with kt in cute.copy. The pipeline's + # acquire/release machinery still tracks the kv_stage ring buffer + # dynamically at runtime, so the producer correctly blocks on + # consumer release when n_kv_tiles > kv_stage. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) @@ -231,20 +231,18 @@ class FmhaV3StageCMulti: kvp.reset() for kt in cutlass.range_constexpr(self.n_kv_tiles): kvh = kvp.acquire_and_advance() - # tBgK has 4 modes after tma_partition: (((64,128),1),?,?,?). - # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles). - # The old pre-slice (None,None,0,0) set mode 2 to 0, forcing - # all TMA reads to tile 0. With no pre-slice, we index all - # 4 modes explicitly, putting kt at mode 2. + # 8-mode TMA indexing: mode 4 = GMEM tile axis (size n_kv_tiles). + # The 8-None pre-slice above opened the full coord space. + # kt at mode 4 indexes the correct KV tile in GMEM. cute.copy( tma_k, - tBgK[(None, None, kt, 0)], + tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier, ) cute.copy( tma_v, - tVgV[(None, 0, kt, 0)], + tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier, ) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index e2334e40..224d6b65 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -179,11 +179,13 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # After tma_partition, tBgK/tVgV have 4 modes: (((64,128),1),?,?,?). - # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles). - # Do NOT pre-slice — index all 4 modes explicitly in cute.copy. + # After tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions. + # Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles). + # 8-None no-op pre-slice opens the full TMA coord space. # tAgQ is fine with 4-mode slice (Q has only 1 tile). tAgQ = tAgQ[(None,0,None,0)] + tBgK = tBgK[(None,None,None,None,None,None,None,None)] + tVgV = tVgV[(None,None,None,None,None,None,None,None)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -221,8 +223,8 @@ class FmhaV3StageCMulti: kvp.reset(); pk = kvp.try_acquire() for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()