From dbd77f2bc441ac3b41f17a362e2f8983499d8cf0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:28:58 +0000 Subject: [PATCH] =?UTF-8?q?DOCUMENT:=20TMA=208-mode=20indexing=20=E2=80=94?= =?UTF-8?q?=20the=20bug=20that=20cost=20us=20a=20full=20day.=20README=20+?= =?UTF-8?q?=20inline=20comments.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 62 +++++++++++++++++++++++++++++++++ tests/unit/test_fmha_v3_diag.py | 12 ++++++- 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bee728f1..81ca4e65 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,67 @@ # DSV4 Inference Kernel +## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️ + +**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.** + +After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, NOT 4: + +``` +tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) + 0 1 2 3 4 5 6 7 +``` + +**Mode 4 (0-indexed) is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles. + +### THE WRONG WAY (what we did — silently loads from tile 0 forever): + +```python +# ❌❌❌ THIS ONLY ADDRESSES 4 OF THE 8 MODES ❌❌❌ +# Mode 4 (the GMEM tile dim) is NEVER touched by this slice. +# It stays fixed at 0. The TMA ALWAYS reads from tile 0. +tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Only 4 modes addressed! + +# The copy "works" but kv_coord is meaningless — it indexes mode 1, +# which has size 1, so every coordinate maps to the same TMA descriptor. +cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored! +``` + +### THE RIGHT WAY (what actually works): + +```python +# ✅ Keep ALL 8 modes. Do NOT pre-slice the TMA GMEM tensor. +tBgK = tBgK[(None, None, None, None, None, None, None, None)] + +# ✅ Index mode 4 (the GMEM tile dim) in the copy call +cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) +# ^^ THIS IS MODE 4 — THE GMEM TILE DIM +``` + +### WHY THIS IS SO INSIDIOUS + +1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently drops modes 4-7. +2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible. +3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN. +4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone. +5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode. + +### THE LESSON + +**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors with fewer dimensions than they actually have.** + +```python +# After tma_partition, ALWAYS verify the shape: +print(f"tBgK shape: {cute.shape(tBgK)}") # Should show 8 modes +print(f"n_kv_tiles at mode 4: {cute.size(tBgK, mode=[4])}") + +# Use full indexing in cute.copy, not a pre-sliced 2D view +cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...) +``` + +**IF YOU SKIP THIS AND PRE-SLICE TO 4 MODES, MULTI-TILE TMA WILL BE SILENTLY BROKEN.** + +--- + ## Architecture DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different. diff --git a/tests/unit/test_fmha_v3_diag.py b/tests/unit/test_fmha_v3_diag.py index 5251492c..30336de7 100644 --- a/tests/unit/test_fmha_v3_diag.py +++ b/tests/unit/test_fmha_v3_diag.py @@ -138,7 +138,15 @@ class FmhaV3Diag: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,None,None,None,None,None)]; tVgV = tVgV[(None,None,None,None,None,None,None,None)] + # ===================================================================== + # ⚠️⚠️⚠️ CRITICAL: TMA PARTITION TENSOR MODE ORDERING ⚠️⚠️⚠️ + # After tma_partition, tBgK/tVgV have 8 modes: (1,1,1,1,n_kv_tiles,1,1,1) + # Mode 4 is the GMEM tile dimension. DO NOT pre-slice to fewer modes! + # See README.md for details. + # ===================================================================== + tAgQ = tAgQ[(None,0,None,0)] + tBgK = tBgK[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice! + tVgV = tVgV[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice! tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -171,6 +179,8 @@ class FmhaV3Diag: kv_coord = Int32(0 + 0) for kt in cutlass.range(self.n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) + # ⚠️ CRITICAL: kv_coord indexes MODE 4 of 8-mode tBgK/tVgV. + # Using (None, kv_coord) on a pre-sliced 4-mode tensor SILENTLY BREAKS multi-tile! cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kv_coord += 1