DOCUMENT: TMA 8-mode indexing — the bug that cost us a full day. README + inline comments.

2026-05-22 21:28:58 +00:00
parent 3a4524c318
commit dbd77f2bc4
2 changed files with 73 additions and 1 deletions

@@ -1,5 +1,67 @@
# DSV4 Inference Kernel
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, NOT 4:
```
tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
mode index:  0  1  2  3  4           5  6  7
```
**Mode 4 (0-indexed) is the GMEM tile dimension.** It is the mode you index with `kt` to load different K/V tiles.
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
```python
# ❌❌❌ THIS ONLY ADDRESSES 4 OF THE 8 MODES ❌❌❌
# Mode 4 (the GMEM tile dim) is NEVER touched by this slice.
# It stays fixed at 0. The TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Only 4 modes addressed!
# The copy "works" but kv_coord is meaningless — it indexes mode 1,
# which has size 1, so every coordinate maps to the same TMA descriptor.
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored!
```
### THE RIGHT WAY (what actually works):
```python
# ✅ Keep ALL 8 modes. Do NOT pre-slice the TMA GMEM tensor.
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
# ✅ Index mode 4 (the GMEM tile dim) in the copy call
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
#                                             ^^ THIS IS MODE 4, THE GMEM TILE DIM
```
### WHY THIS IS SO INSIDIOUS
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` addresses only modes 0-3. Modes 4-7 are never indexed, and nothing complains.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
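Points 2 and 3 can be reproduced without a GPU. The following is a toy pure-Python model, not the kernel: `attention`, `compare`, and the `stuck_on_tile_0` flag are hypothetical names made up for this sketch, and the data is random. It runs single-query attention over tiled K/V twice, once reading tile `kt` and once reading tile 0 on every iteration, and reports how far apart the results are.

```python
import math
import random


def softmax(xs):
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k_tiles, v_tiles, stuck_on_tile_0=False):
    """Single-query attention over tiled K/V (toy model, hypothetical helper).

    stuck_on_tile_0=True mimics the bug: the loop still runs once per
    tile, but every load returns tile 0.
    """
    keys, values = [], []
    for kt in range(len(k_tiles)):
        src = 0 if stuck_on_tile_0 else kt
        keys.extend(k_tiles[src])
        values.extend(v_tiles[src])
    scale = 1.0 / math.sqrt(len(q))
    probs = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
    dim = len(values[0])
    return [sum(p * v[d] for p, v in zip(probs, values)) for d in range(dim)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def compare(n_tiles, tile_len=8, dim=16, seed=0):
    """Return (cosine similarity, max abs difference) of correct vs. stuck loads."""
    rng = random.Random(seed)

    def tiles():
        return [[[rng.gauss(0.0, 1.0) for _ in range(dim)]
                 for _ in range(tile_len)] for _ in range(n_tiles)]

    q = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    k_tiles, v_tiles = tiles(), tiles()
    good = attention(q, k_tiles, v_tiles)
    bad = attention(q, k_tiles, v_tiles, stuck_on_tile_0=True)
    return cosine(good, bad), max(abs(x - y) for x, y in zip(good, bad))


if __name__ == "__main__":
    for n in (1, 4):
        cos, diff = compare(n)
        print(f"n_tiles={n}: cosine={cos:.4f}, max_abs_diff={diff:.3e}")
```

With one tile the two paths are identical, so a single-tile test cannot catch the bug. With several tiles the stuck path still returns finite numbers, which is why a NaN check is not enough. The exact similarity depends on the data, so do not read the toy's numbers as the kernel's.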
### THE LESSON
**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors with an index tuple that has fewer entries than the tensor has modes.**
```python
# After tma_partition, ALWAYS verify the shape:
print(f"tBgK shape: {cute.shape(tBgK)}") # Should show 8 modes
print(f"n_kv_tiles at mode 4: {cute.size(tBgK, mode=[4])}")
# Use full indexing in cute.copy, not a pre-sliced 2D view
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
```
**IF YOU SKIP THIS AND PRE-SLICE TO 4 MODES, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
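One way to stop hand-counting `None`s is to build the index tuple from the mode count. This is a minimal sketch in plain Python: `tile_index` and `check_tile_mode` are hypothetical helpers, not part of CuTe. It assumes the partitioned shape is a flat tuple of static ints, like the `(1, 1, 1, 1, n_kv_tiles, 1, 1, 1)` shown above. If `cute.shape` returns something nested or dynamic in your setup, adapt the check.

```python
def tile_index(n_modes, tile_mode, coord):
    """Full-rank index tuple: `coord` at `tile_mode`, None (keep) everywhere else."""
    if not 0 <= tile_mode < n_modes:
        raise ValueError(f"tile_mode {tile_mode} is out of range for {n_modes} modes")
    idx = [None] * n_modes
    idx[tile_mode] = coord
    return tuple(idx)


def check_tile_mode(shape, tile_mode, n_kv_tiles):
    """Fail loudly if the partitioned shape is not what the KV loop assumes."""
    if len(shape) <= tile_mode:
        raise ValueError(f"expected more than {tile_mode} modes, got shape {shape}")
    if shape[tile_mode] != n_kv_tiles:
        raise ValueError(
            f"mode {tile_mode} has size {shape[tile_mode]}, expected {n_kv_tiles}"
        )


if __name__ == "__main__":
    shape = (1, 1, 1, 1, 6, 1, 1, 1)   # stand-in for the partitioned tBgK shape
    check_tile_mode(shape, tile_mode=4, n_kv_tiles=6)
    print(tile_index(len(shape), 4, 2))
```

In Python, `t[a, b]` and `t[(a, b)]` pass the same tuple to `__getitem__`, so `tBgK[tile_index(8, 4, kt)]` is the same index expression as writing out the eight entries by hand. Running `check_tile_mode` once after `tma_partition` turns a wrong mode assumption into an error instead of a silent wrong load.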
---
## Architecture
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.

@@ -138,7 +138,15 @@ class FmhaV3Diag:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,None,None,None,None,None)]; tVgV = tVgV[(None,None,None,None,None,None,None,None)]
# =====================================================================
# ⚠️⚠️⚠️ CRITICAL: TMA PARTITION TENSOR MODE ORDERING ⚠️⚠️⚠️
# After tma_partition, tBgK/tVgV have 8 modes: (1,1,1,1,n_kv_tiles,1,1,1)
# Mode 4 is the GMEM tile dimension. DO NOT pre-slice to fewer modes!
# See README.md for details.
# =====================================================================
tAgQ = tAgQ[(None,0,None,0)]
tBgK = tBgK[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
tVgV = tVgV[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -171,6 +179,8 @@ class FmhaV3Diag:
kv_coord = Int32(0 + 0)
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# ⚠️ CRITICAL: kv_coord indexes MODE 4 of 8-mode tBgK/tVgV.
# Using (None, kv_coord) on a pre-sliced 4-mode tensor SILENTLY BREAKS multi-tile!
cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1