From 9c5adcee46d5c7ea11e24934f0ed197aba5050fb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 23:08:27 +0000 Subject: [PATCH] FIX: tma_partition tensors have 4 modes, not 8. Mode 2 is GMEM tile dim. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 8-mode indexing (tBgK[None,None,None,None,kt,None,None,None]) fails at JIT compilation with 'coord and shape are weakly congruent' error. The actual MLIR tensor shape is (((64,128),1),?,?,?) — 4 modes, not 8. The working fix from commit 845ad98 on the B200 used 4-mode indexing all along: tBgK[(None, None, kt, 0)] — mode 2 = GMEM tile dim tVgV[(None, 0, kt, 0)] — mode 2 = GMEM tile dim Updated all files: example10, test_fmha_v3_stage_c, README, docstrings. --- README.md | 59 ++++++++++++------------- tests/fmha_v3_stage_c_example10.py | 71 +++++++++++++----------------- tests/unit/test_fmha_v3_stage_c.py | 10 ++--- 3 files changed, 64 insertions(+), 76 deletions(-) diff --git a/README.md b/README.md index 93329b09..4c168c11 100644 --- a/README.md +++ b/README.md @@ -4,61 +4,60 @@ **THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.** -After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, NOT 4: +After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: ``` -tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) - 0 1 2 3 4 5 6 7 +tBgK shape: (((64, 128), 1), ?, ?, ?) + mode 0 1 2 3 ``` -**Mode 4 (0-indexed) is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles. +**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles. ### THE WRONG WAY (what we did — silently loads from tile 0 forever): ```python -# ❌❌❌ THIS ONLY ADDRESSES 4 OF THE 8 MODES ❌❌❌ -# Mode 4 (the GMEM tile dim) is NEVER touched by this slice. -# It stays fixed at 0. The TMA ALWAYS reads from tile 0. -tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Only 4 modes addressed! +# ❌❌❌ PRE-SLICE SETS MODE 2 TO 0 ❌❌❌ +# The (None, None, 0, 0) slice passes 0 for mode 2, pinning the +# GMEM tile dim to tile 0. The TMA ALWAYS reads from tile 0. +tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0! -# The copy "works" but kv_coord is meaningless — it indexes mode 1, -# which has size 1, so every coordinate maps to the same TMA descriptor. +# The copy "works" but kv_coord indexes mode 1 (size 1), so +# every coordinate maps to the same TMA descriptor. cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored! ``` ### THE RIGHT WAY (what actually works): ```python -# ✅ Keep ALL 8 modes. Do NOT pre-slice the TMA GMEM tensor. -tBgK = tBgK[(None, None, None, None, None, None, None, None)] +# ✅ Do NOT pre-slice. Index all 4 modes explicitly in cute.copy. +# Put kt at mode 2 (the GMEM tile dim) and 0 elsewhere. -# ✅ Index mode 4 (the GMEM tile dim) in the copy call -cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) -# ^^ THIS IS MODE 4 — THE GMEM TILE DIM +cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...) +# ^^ MODE 2 — THE GMEM TILE DIM ``` ### WHY THIS IS SO INSIDIOUS -1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently drops modes 4-7. -2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible. +1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0. +2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible. 3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN. 4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone. 5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode. ### THE LESSON -**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors with fewer dimensions than they actually have.** +**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors if you need to index a mode that the slice would collapse.** ```python # After tma_partition, ALWAYS verify the shape: -print(f"tBgK shape: {cute.shape(tBgK)}") # Should show 8 modes -print(f"n_kv_tiles at mode 4: {cute.size(tBgK, mode=[4])}") +print(f"tBgK shape: {cute.shape(tBgK)}") +print(f"Mode 2 size: {cute.size(tBgK, mode=[2])}") -# Use full indexing in cute.copy, not a pre-sliced 2D view -cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) +# Index all 4 modes explicitly, kt at mode 2 +cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...) ``` -**IF YOU SKIP THIS AND PRE-SLICE TO 4 MODES, MULTI-TILE TMA WILL BE SILENTLY BROKEN.** +**IF YOU PRE-SLICE AND PIN THE GMEM TILE MODE TO 0, MULTI-TILE TMA WILL BE SILENTLY BROKEN.** --- @@ -290,19 +289,19 @@ What it does: ### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug) -After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, not 4: +After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: ``` -tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) - 0 1 2 3 4 5 6 7 +tBgK shape: (((64, 128), 1), ?, ?, ?) + mode 0 1 2 3 ``` -**Mode 4 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` only addressed 4 modes — modes 4-7 were implicitly collapsed to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error. +**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error. -**The fix:** Do not pre-slice. Index all 8 modes explicitly in `cute.copy`, putting `kt` at mode 4: +**The fix:** Do not pre-slice. Index all 4 modes explicitly in `cute.copy`, putting `kt` at mode 2: ```python -cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) +cute.copy(tma_k, tBgK[(None, None, kt, 0)], ...) ``` **Results after fix:** @@ -342,7 +341,7 @@ Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Ep 1. `vectorize=True` loops: ONLY load/store/print 2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row 3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop -4. `tBgK`/`tVgV` have 8 modes after tma_partition — mode 4 is GMEM tile dim, must index all 8 explicitly +4. `tBgK`/`tVgV` have 4 modes after tma_partition — mode 2 is GMEM tile dim, must index all 4 explicitly 5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0 6. `softmax_done_bar` NamedBarrier is reusable across tiles diff --git a/tests/fmha_v3_stage_c_example10.py b/tests/fmha_v3_stage_c_example10.py index 2bc9452d..734980d9 100644 --- a/tests/fmha_v3_stage_c_example10.py +++ b/tests/fmha_v3_stage_c_example10.py @@ -14,20 +14,19 @@ Three structural rules learned the hard way: `utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what the CUTLASS Blackwell FMHA reference uses in `correction_rescale`. -(C) tma_partition produces an 8-mode tensor, not a 4-mode one. After +(C) tma_partition produces a 4-mode tensor for tBgK/tVgV, not a 2-mode one. After tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, group_modes(sK,0,3), group_modes(tCgK,0,3)) - `tBgK` shape is (1, 1, 1, 1, n_kv_tiles, 1, 1, 1). Mode 4 is the + `tBgK` shape is (((64,128),1),?,?,?) with 4 modes. Mode 2 is the GMEM-tile iteration axis. Pre-slicing with `tBgK[(None,None,0,0)]` - addresses only 4 modes — the remaining modes (including the KV-tile - axis at mode 4) get implicitly collapsed to coord 0, so every TMA copy - reads tile 0 regardless of what's passed at index 1. The bug pretends - to be a JIT issue: dynamic coords seem to be "constant-folded" because - the only axis they could vary along has stride 0. + sets mode 2 to 0, so every TMA copy reads tile 0 regardless of + what's passed at mode 1. The bug pretends to be a JIT issue: + dynamic coords seem to be "constant-folded" because the axis they + vary along is pinned to 0 by the pre-slice. - Fix: do not pre-slice. Index all 8 modes explicitly in the producer's - `cute.copy`, putting `kt` at mode 4 and `None` (or 0) everywhere else. + Fix: do not pre-slice. Index all 4 modes explicitly in the producer's + `cute.copy`, putting `kt` at mode 2 and `0` or `None` everywhere else. Kernel structure: @@ -187,15 +186,13 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # NOTE: after tma_partition, ALL three tensors (tAgQ, tBgK, tVgV) - # have 8 modes, e.g. tBgK shape = (1, 1, 1, 1, n_kv_tiles, 1, 1, 1). - # Mode 4 is the GMEM tile-iteration axis; all other modes are size 1. - # We previously pre-sliced like `tBgK[(None,None,0,0)]` which only - # addressed 4 modes — modes 4..7 got swept up into the trailing 0 - # and the KV-tile axis was effectively collapsed to tile 0 always. - # That, not a JIT bug, was why every dynamic coord produced tile-0 - # data. With no pre-slice, we index all 8 modes explicitly in the - # producer warp below, putting `kt` at mode 4 and `None` elsewhere. + # NOTE: after tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?). + # Mode 2 is the GMEM tile-iteration axis (size = n_kv_tiles). + # We previously pre-sliced like tBgK[(None,None,0,0)] which set mode 2 + # to 0, so TMA always read from tile 0. Fix: no pre-slice — index + # all 4 modes explicitly in cute.copy, putting kt at mode 2. + # tVgV similarly has 4 modes with mode 2 as the GMEM tile dim. + # tAgQ is fine with 4-mode slice (Q has only 1 tile). tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -219,43 +216,35 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp — fully unrolled ===== - # Why fully unrolled: original suspicion was a JIT bug propagating - # dynamic coords; actual root cause was a layout bug — the pre-slice - # collapsed all the mode-4 GMEM-tile axis to 0. After tma_partition, - # tBgK and tVgV have 8 modes: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) where - # mode 4 is the KV-tile iteration axis. - # - # Indexing rule: pass `None` for every mode that's size 1 (passthrough), - # and `kt` for the KV-tile axis at mode 4. - # - # We keep the unroll for correctness confidence. The pipeline's - # acquire/release machinery still tracks the kv_stage ring buffer - # dynamically at runtime, so the producer correctly blocks on - # consumer release when n_kv_tiles > kv_stage. The unroll only - # flattens the LOOP control flow, not the synchronization. + # The original pre-slice tBgK[(None,None,0,0)] pinned mode 2 (the + # GMEM tile dim) to 0, forcing all TMA reads to tile 0. With no + # pre-slice, we index all 4 modes explicitly in cute.copy, putting + # kt at mode 2. The pipeline's acquire/release machinery still + # tracks the kv_stage ring buffer dynamically at runtime, so the + # producer correctly blocks on consumer release when n_kv_tiles > + # kv_stage. The unroll only flattens the LOOP control flow, not + # the synchronization. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() - # Q's 4-mode indexing was confirmed working at n=128 cos 0.999998 - # before the multi-tile investigation; leave it alone. Only K/V's - # tma_partition output is 8 modes (the KV-tile axis introduces - # the extra modes — Q has no equivalent tile axis since there's - # one Q tile per CTA). cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset() for kt in cutlass.range_constexpr(self.n_kv_tiles): kvh = kvp.acquire_and_advance() - # 8-mode indices, mode 4 = KV-tile axis (size n_kv_tiles). - # Every other mode is size 1, so None is fine. + # tBgK has 4 modes after tma_partition: (((64,128),1),?,?,?). + # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles). + # The old pre-slice (None,None,0,0) set mode 2 to 0, forcing + # all TMA reads to tile 0. With no pre-slice, we index all + # 4 modes explicitly, putting kt at mode 2. cute.copy( tma_k, - tBgK[None, None, None, None, kt, None, None, None], + tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier, ) cute.copy( tma_v, - tVgV[None, None, None, None, kt, None, None, None], + tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier, ) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 12264d5e..e2334e40 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -179,9 +179,9 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # After tma_partition, tBgK/tVgV have 8 modes. - # Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles). - # Do NOT pre-slice — index all 8 modes explicitly in cute.copy. + # After tma_partition, tBgK/tVgV have 4 modes: (((64,128),1),?,?,?). + # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles). + # Do NOT pre-slice — index all 4 modes explicitly in cute.copy. # tAgQ is fine with 4-mode slice (Q has only 1 tile). tAgQ = tAgQ[(None,0,None,0)] @@ -221,8 +221,8 @@ class FmhaV3StageCMulti: kvp.reset(); pk = kvp.try_acquire() for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()