DOCUMENT: TMA 8-mode indexing — the bug that cost us a full day. README + inline comments.
This commit is contained in:
62
README.md
62
README.md
@@ -1,5 +1,67 @@
|
||||
# DSV4 Inference Kernel
|
||||
|
||||
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
|
||||
|
||||
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
|
||||
|
||||
After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, NOT 4:
|
||||
|
||||
```
|
||||
tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
|
||||
0 1 2 3 4 5 6 7
|
||||
```
|
||||
|
||||
**Mode 4 (0-indexed) is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
|
||||
|
||||
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
|
||||
|
||||
```python
|
||||
# ❌❌❌ THIS ONLY ADDRESSES 4 OF THE 8 MODES ❌❌❌
|
||||
# Mode 4 (the GMEM tile dim) is NEVER touched by this slice.
|
||||
# It stays fixed at 0. The TMA ALWAYS reads from tile 0.
|
||||
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Only 4 modes addressed!
|
||||
|
||||
# The copy "works" but kv_coord is meaningless — it indexes mode 1,
|
||||
# which has size 1, so every coordinate maps to the same TMA descriptor.
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord is ignored!
|
||||
```
|
||||
|
||||
### THE RIGHT WAY (what actually works):
|
||||
|
||||
```python
|
||||
# ✅ Keep ALL 8 modes. Do NOT pre-slice the TMA GMEM tensor.
|
||||
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
|
||||
|
||||
# ✅ Index mode 4 (the GMEM tile dim) in the copy call
|
||||
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
|
||||
# ^^ THIS IS MODE 4 — THE GMEM TILE DIM
|
||||
```
|
||||
|
||||
### WHY THIS IS SO INSIDIOUS
|
||||
|
||||
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently drops modes 4-7.
|
||||
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
|
||||
3. **All multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
|
||||
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
|
||||
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
|
||||
|
||||
### THE LESSON
|
||||
|
||||
**ALWAYS print `cute.shape(tBgK)` after `tma_partition`. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors with fewer dimensions than they actually have.**
|
||||
|
||||
```python
|
||||
# After tma_partition, ALWAYS verify the shape:
|
||||
print(f"tBgK shape: {cute.shape(tBgK)}") # Should show 8 modes
|
||||
print(f"n_kv_tiles at mode 4: {cute.size(tBgK, mode=[4])}")
|
||||
|
||||
# Use full indexing in cute.copy, not a pre-sliced 2D view
|
||||
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
|
||||
```
|
||||
|
||||
**IF YOU SKIP THIS AND PRE-SLICE TO 4 MODES, MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.
|
||||
|
||||
@@ -138,7 +138,15 @@ class FmhaV3Diag:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,None,None,None,None,None)]; tVgV = tVgV[(None,None,None,None,None,None,None,None)]
|
||||
# =====================================================================
|
||||
# ⚠️⚠️⚠️ CRITICAL: TMA PARTITION TENSOR MODE ORDERING ⚠️⚠️⚠️
|
||||
# After tma_partition, tBgK/tVgV have 8 modes: (1,1,1,1,n_kv_tiles,1,1,1)
|
||||
# Mode 4 is the GMEM tile dimension. DO NOT pre-slice to fewer modes!
|
||||
# See README.md for details.
|
||||
# =====================================================================
|
||||
tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBgK = tBgK[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
|
||||
tVgV = tVgV[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -171,6 +179,8 @@ class FmhaV3Diag:
|
||||
kv_coord = Int32(0 + 0)
|
||||
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
# ⚠️ CRITICAL: kv_coord indexes MODE 4 of 8-mode tBgK/tVgV.
|
||||
# Using (None, kv_coord) on a pre-sliced 4-mode tensor SILENTLY BREAKS multi-tile!
|
||||
cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord += 1
|
||||
|
||||
Reference in New Issue
Block a user