# DSV4 Inference Kernel
🚀🚀🚀 TMA MULTI-TILE FIX VERIFIED ON B200 🚀🚀🚀
THE BUG: tBgK[(None,None,0,0)] kept modes 0,1 free but set mode 2 (KV tiles) to 0.
TMA always loaded from tile 0 regardless of the coordinate value.
This was a LAYOUT bug, NOT a JIT bug, NOT a CuTeDSL bug.
THE FIX: tBgK[(None,0,None,0)] keeps modes 0 and 2 free.
Then tBgK[None, kt] indexes the surviving KV_tiles dim.
VERIFIED SHAPES (B200, n=256, inside @cute.kernel):
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
TEST RESULTS (test_fmha_v3_stage_c.py, identity softmax):
n=128: cos 0.999998 ✅ PASS
n=256: cos 0.71 (TMA loads 2 tiles, needs O rescale for 0.9999)
n=512+: same output as n=256 (pipeline not cycling past kv_stage=2)
example10 (real softmax + O rescale): compiles and runs, cos ~0.47 (softmax bugs separate from TMA)
LESSON: PRINT THE SHAPES. ALWAYS. Reasoning about mode counts without
evidence is how we wasted a day. The 8-mode theory was WRONG — 8-None
slice fails with 'weakly congruent' at JIT compile. The tensor has 4 modes.
Updated: README (verified shapes, correct fix), MEMORY.md (new rules),
test_fmha_v3_stage_c.py, test_fmha_v3_diag.py, example10, test_fmha_v3.py,
fire_b200_test (clean git state, kill all old processes).
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**

After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes** (verified on B200):
```
tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
mode               0         1     2      3
```
**Mode 2 is the GMEM tile dimension.** It is the dimension you index with `kt` to load different K/V tiles.
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
```python
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
# TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord indexes wrong mode!
```
### THE RIGHT WAY (verified on B200 at n=128 and n=256):
```python
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
# Mode 2 (KV tiles) survives as the second mode.
tBgK = tBgK[(None, 0, None, 0)]
# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
cute.copy(tma_k, tBgK[None, kt], ...)
#                           ^^ THIS IS THE KV TILE DIM
```
**Verified shapes on B200 (May 22, n=256, inside `@cute.kernel`):**
```
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
```
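The difference between the two slices can be reproduced with plain NumPy. This is only an analogy: a NumPy array stands in for the CuTe tensor, `:` plays the role of `None`, and the extents `(8, 3, 2, 5)` are made up for illustration, with mode 2 as the KV-tile axis. A minimal sketch under those assumptions:

```python
import numpy as np

# Hypothetical 4-mode tensor: (tile_elems, mode1, kv_tiles, mode3).
# Every element stores the index of the KV tile it belongs to.
tile_elems, mode1, kv_tiles, mode3 = 8, 3, 2, 5
t = np.zeros((tile_elems, mode1, kv_tiles, mode3), dtype=np.int64)
t += np.arange(kv_tiles).reshape(1, 1, kv_tiles, 1)

# Analogue of (None, None, 0, 0): mode 2 is pinned to 0, mode 1 survives.
wrong = t[:, :, 0, 0]
# Analogue of (None, 0, None, 0): mode 1 is pinned to 0, mode 2 survives.
right = t[:, 0, :, 0]

print(wrong.shape)  # (8, 3)
print(right.shape)  # (8, 2)

# Index the surviving second mode with kt and see which tile is read.
wrong_tiles = [int(wrong[0, kt]) for kt in range(kv_tiles)]
right_tiles = [int(right[0, kt]) for kt in range(kv_tiles)]
print(wrong_tiles)  # [0, 0]
print(right_tiles)  # [0, 1]
```

Both slices yield a 2-mode result, so the rank alone does not reveal which one is correct. Only the contents show that the first slice never leaves tile 0.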
### WHY THIS IS SO INSIDIOUS
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
3. **Multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't: the variable was fine, but it was indexing the wrong mode.
6. **The 8-mode theory was wrong.** We assumed `tma_partition` produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.
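Points 2 and 3 can be illustrated numerically without a GPU. The sketch below is not the kernel: it uses NumPy, real softmax rather than the identity softmax of the Stage C test, random inputs, and a hypothetical helper `stuck_on_tile0` that mimics a loader which always returns tile 0. The tile size of 128 and head dimension of 64 are assumptions chosen for illustration.

```python
import numpy as np

def attention(q, k, v):
    """Plain softmax attention over the whole K/V (the reference)."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def stuck_on_tile0(x, tile):
    """Mimic the bug: every tile-sized load returns the rows of tile 0."""
    return np.tile(x[:tile], (x.shape[0] // tile, 1))

def cosine(a, b):
    a, b = a.ravel(), b.ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
tile, d = 128, 64
results = {}
for n in (128, 256):
    q = rng.standard_normal((tile, d))
    k = rng.standard_normal((n, d))
    v = rng.standard_normal((n, d))
    ref = attention(q, k, v)
    bad = attention(q, stuck_on_tile0(k, tile), stuck_on_tile0(v, tile))
    # Store (cosine vs reference, all outputs finite?) for each n.
    results[n] = (cosine(ref, bad), bool(np.isfinite(bad).all()))
    print(n, results[n])
```

With a single tile the stuck loader matches the reference, so the cosine is 1 up to rounding. With two tiles the output stays finite and positively correlated with the reference, but the cosine drops well below 1, which is the kind of plausible-looking failure described above.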
### THE LESSON
**PRINT THE SHAPES. ALWAYS.** Run `print(f"tBgK: shape={cute.shape(tBgK)}")` inside `@cute.kernel` at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.

**The correct pre-slice depends on which mode is the GMEM tile iteration axis.** For our `local_tile` + `partition_B` + `group_modes(0,3)` pattern, mode 2 is the KV tile axis. `(None,0,None,0)` keeps it free. `(None,None,0,0)` collapses it to 0.
```python
# ALWAYS verify the shape at trace time:
print(f"tBgK shape: {cute.shape(tBgK)}") # 4 modes
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}") # 2 modes
# Then index the 2D tensor:
cute.copy(tma_k, tBgK[None, kt], ...)
```
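The bookkeeping behind that check can be modelled in plain Python, which may help when reasoning about a slice before running it on the GPU. This is a sketch, not CuTeDSL: `mode_count` and `sliced_shape` are hypothetical helpers, the shape is written as a nested tuple with `"?"` for dynamic extents, and the `ValueError` only stands in for the "weakly congruent" failure reported at JIT compile.

```python
def mode_count(shape):
    """Number of top-level modes in a (possibly nested) shape tuple."""
    return len(shape)

def sliced_shape(shape, coord):
    """Modes whose coordinate is None survive; fixed coordinates are dropped."""
    if len(coord) != mode_count(shape):
        raise ValueError(
            f"coordinate has {len(coord)} entries, shape has {mode_count(shape)} modes"
        )
    return tuple(m for m, c in zip(shape, coord) if c is None)

# Shape as printed at trace time; "?" marks a dynamic extent.
before = (((64, 128), 1), "?", "KV_tiles", "?")
assert mode_count(before) == 4

right = sliced_shape(before, (None, 0, None, 0))
wrong = sliced_shape(before, (None, None, 0, 0))
print(right)  # (((64, 128), 1), 'KV_tiles')
print(wrong)  # (((64, 128), 1), '?')
```

In this model an 8-entry coordinate against the 4-mode shape is rejected immediately, while the wrong 4-entry slice passes and simply keeps a different mode.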
**IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.**

---
## Architecture
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. The KV latent is (T, 512), shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA". It is not: the architecture is fundamentally different.
```
DSV4 inference pipeline — component status
==========================================
Legend:
[✓] built and tested
[~] partial — reference or seam exists, native pending
[✗] to build
                  ┌────────────────────────────────────┐
                  │  [✗] Embedding + mHC init          │
                  │      token embed + n_hc=4 streams  │
                  └────────────────┬───────────────────┘
                                   │
                                   ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│  HCA on layers 0–1 of Pro, alternating CSA / HCA after               │
│                                                                      │
│  ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│  │  [✓] Residual mHC        pre + post mix                        │  │
│  │  [~] Norms + RoPE        RMSNorm + partial RoPE                │  │
│  │  [✓] Q / KV projection   NVFP4 linears + LoRA                  │  │
│  │  [~] Token compressor    CSA m=4 / HCA m′=128                  │  │
│  │  [✗] Indexer + top-k     CSA only, FP4 QK                      │  │
│  │  [~] FMHA core           QK → online softmax → PV              │  │
│  │                          + SWA branch + sink merge             │  │
│  │  [✓] Output projection   inv RoPE + wo_a grouped + wo_b        │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│  │  [✓] Residual mHC        pre + post mix                        │  │
│  │  [~] Pre-FFN norm        RMSNorm                               │  │
│  │  [✗] Router              sqrt(softplus) + topk + hash          │  │
│  │  [✓] Routed MoE          fused SwiGLU L1 + L2                  │  │
│  │  [✓] Shared expert       NVFP4 single-group GEMM               │  │
│  └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                   │
                                   ▼
┌──────────────────────────────────────────────────────────────────────┐
│  [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler   │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│  [✗] KV cache management                                             │
│      • state cache: SWA window + uncompressed tail per layer         │
│      • classical paged cache: lcm(m, m′) = 128 tokens per block      │
│      • heterogeneous layout per layer                                │
└──────────────────────────────────────────────────────────────────────┘
Summary
-------
Built [✓] : 6 — mHC × 2, Q/KV proj, output proj, routed MoE,
shared expert
Partial [~] : 4 — norms+RoPE, token compressor, FMHA core,
pre-FFN norm
To build [✗] : 8 — embedding+init, indexer+top-k, router,
final norm, LM head, MTP, sampler, KV cache
```
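The paged-cache block size in the diagram follows from simple arithmetic. The sketch below assumes that a compression ratio of m means m tokens collapse into one compressed KV entry. That reading is an assumption made here for illustration, and the variable names are hypothetical.

```python
import math

m_csa = 4     # CSA compression ratio m (from the text)
m_hca = 128   # HCA compression ratio m′ (from the text)

# Smallest token count that both ratios divide evenly.
block_tokens = math.lcm(m_csa, m_hca)

# Under the assumption above: compressed entries per block for each layer type.
csa_entries = block_tokens // m_csa
hca_entries = block_tokens // m_hca

print(block_tokens, csa_entries, hca_entries)  # 128 32 1
```

Under that assumption, a 128-token block holds a whole number of compressed entries for both layer types, which is one plausible reason for choosing the least common multiple as the block size.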
---
## Status (May 22, 2026 — 16:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ MULTI-TILE TMA FIXED | n=128 cos 0.999998 ✅. TMA fix: n=256 loads 2 tiles. Pipeline cycling needed for n≥384. O rescale needed. |
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |

---
## Package Structure
```
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│ ├── attention/ FMHA kernel (stub — extraction is Stage E)
│ ├── compressor/ CSA/HCA token-level compressor
│ ├── decode/ Decode-time attention (sparse, SWA — future)
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── decode_swa.py native_swa_decode dispatcher
│ ├── rope.py Forward + inverse RoPE
│ └── topk.py Python wrapper for sparse_topk_metadata.cu
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear
│ ├── moe.py Nvfp4MoE
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer
│ └── (stubs: attention, ffn, router, norm, embedding)
├── model/ Model assembly (stubs — Phase 1)
├── cache/ KV cache infra (stubs — Phase 3)
├── loader/ Checkpoint I/O (stubs — Phase 1)
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py RoPE, KV cache, causal attention, SWA
├── csa_attention.py CSA/HCA sparse attention
├── compressor.py Compressor PyTorch example
└── moe_pipeline.py MoE pipeline reference
```
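The rule that `reference/` is never imported by production code could be checked mechanically. A minimal sketch, assuming the package is importable as `dsv4` (as the tree suggests) and that only absolute imports need checking. `imports_reference` is a hypothetical helper, not part of the repo:

```python
import ast

def imports_reference(source: str) -> bool:
    """True if the module source imports anything under dsv4.reference.

    Only absolute imports are inspected; relative imports are ignored.
    """
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module:
            # Cover both `from dsv4.reference.x import y`
            # and `from dsv4 import reference`.
            names = [node.module] + [f"{node.module}.{a.name}" for a in node.names]
        else:
            continue
        for name in names:
            if name == "dsv4.reference" or name.startswith("dsv4.reference."):
                return True
    return False

print(imports_reference("from dsv4.reference.attention import rope"))  # True
print(imports_reference("from dsv4.ops.rope import apply"))            # False
```

A test could run this over every file under `kernels/`, `ops/`, `layers/` and `model/` and fail if any returns `True`. The source is parsed, not executed, so the check needs neither a GPU nor the package's dependencies.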
**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.

---
## Active Test Files
### FMHA (Stages A/B/C) — in `tests/unit/`
| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c_full.py` | C | ✅ Real online softmax + O normalization, cosine 0.993-0.996 |
| `test_fmha_v3_stage_c_min.py` | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |
### MoE / GEMM — in `tests/unit/`
| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | Cudagraph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |
### Archived Tests
`tests/archive/` contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.

---
## Test Harness
Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):
### `run_test.sh` — Run a test in a screen session
```bash
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
```
What it does:
1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU procs that ignore SIGHUP)
2. Deletes the old log file
3. Starts a new `screen -dmS kernel-test` running the test
4. Logs output to `/tmp/kernel-test.log`
5. Verifies the screen started
### `check_log.sh` — Check test progress
```bash
bash tests/check_log.sh
```
Shows the log contents and whether the screen is still running.
### Local → B200 workflow
```bash
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push
# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
# 3. Check results
bash tests/check_log.sh
```
### `fire_b200_test` — One-command local test runner
Lives in `~/.openclaw/workspace/fire_b200_test` (NOT in the repo — project-specific tooling).
```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```
What it does:
1. Auto-commits and pushes any local changes
2. SSH to B200, pulls, starts `run_test.sh` in a screen
3. Polls every 15s until the screen exits
4. Dumps the full test log to your terminal
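
The polling in step 3 can be sketched as a small loop. This is a minimal illustration, not the actual script: `wait_until_done`, its parameters, and the timeout are hypothetical, and in practice `is_running` would be whatever check the script uses over SSH (for example, looking for the `kernel-test` session in `screen -ls`). The dry run below uses a fake check so it executes without a B200.

```python
import time
from typing import Callable


def wait_until_done(is_running: Callable[[], bool],
                    interval_s: float = 15.0,
                    timeout_s: float = 3600.0,
                    sleep: Callable[[float], None] = time.sleep) -> float:
    """Poll is_running() every interval_s seconds; return the seconds waited."""
    waited = 0.0
    while is_running():
        if waited >= timeout_s:
            # The timeout is an assumption added for safety, not part of the README.
            raise TimeoutError(f"still running after {waited:.0f}s")
        sleep(interval_s)
        waited += interval_s
    return waited


# Dry run: the fake check reports "running" three times, then "exited".
answers = iter([True, True, True, False])
waited = wait_until_done(lambda: next(answers), sleep=lambda s: None)
print(waited)  # 45.0 (three 15 s polls)
```
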
**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.

---
## Stage C: Online Softmax — Multi-Tile In Progress
### What We Have
- **Working real softmax** for a single KV tile (n=128): cosine 0.999998.
- **Multi-tile TMA indexing fixed** (n=256 cosine 0.9956). This was a layout bug, NOT a JIT bug.
- **Remaining:** O rescale between tiles, pipeline state cycling for n≥384, correction warps.
### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: `(((64,128),1), ?, KV_tiles, ?)`.

**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` kept modes 0,1 free and set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.

**The fix:** `(None,0,None,0)` keeps modes 0,2 free, then `[None, kt]` indexes KV tiles:
```python
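# Pre-slice once: keep mode 0 (TMA box) and mode 2 (KV tiles) free; fix modes 1 and 3 to 0.
tBgK = tBgK[(None, 0, None, 0)]
# Inside the KV loop, kt selects the tile along the surviving KV_tiles mode.
cute.copy(tma_k, tBgK[None, kt], ...)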
```
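
The difference between the two pre-slices can be reproduced with a plain NumPy array standing in for the partitioned tensor. This is only an analogy, not CuTeDSL: the array `t` and its mode sizes are made up, and `:` plays the role of `None`. Modes 1 and 3 are given extent 1 here purely to keep the example small.

```python
import numpy as np

# Hypothetical stand-in for tBgK after tma_partition: 4 modes.
# mode 0 = elements of one TMA box, mode 2 = KV tile index.
BOX, KV_TILES = 8, 4
t = np.zeros((BOX, 1, KV_TILES, 1), dtype=np.int64)
for kt in range(KV_TILES):
    t[:, 0, kt, 0] = kt  # tag every element with the tile it belongs to

# Old pre-slice, analogous to tBgK[(None, None, 0, 0)]: mode 2 is pinned to tile 0.
old = t[:, :, 0, 0]   # shape (BOX, 1); the KV-tile mode is gone
# New pre-slice, analogous to tBgK[(None, 0, None, 0)]: mode 2 survives.
new = t[:, 0, :, 0]   # shape (BOX, KV_TILES)

for kt in range(KV_TILES):
    old_tile = old[:, 0]   # no KV-tile mode left to index: always tile 0
    new_tile = new[:, kt]  # analogous to tBgK[None, kt]
    print(kt, int(old_tile[0]), int(new_tile[0]))
```
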
**Results after TMA fix (verified on B200, May 22):**

- n=128: cos 0.999998 ✅
- n=256: cos 0.71 (TMA loads 2 tiles correctly, needs O rescale for 0.9999)
- n=512/1024: output identical to n=256 (pipeline not cycling past kv_stage=2)

**Verified tensor shapes (diag prints inside `@cute.kernel` on B200, n=256):**
```
Before (None,0,None,0) pre-slice:
tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
tVgV: (((64,128),1), 1, 1, 1) — 4 modes
After (None,0,None,0) pre-slice:
tAgQ: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
tBgK: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
tVgV: (((64,128),1), 1) — 2 modes, mode 1 = 1 (static)
```
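
The n=512/1024 result above is consistent with the KV pipeline never wrapping around its two stages. As a reference for what the cycling should look like, here is a minimal pure-Python sketch of a circular pipeline state. `PipeState` and `schedule` are hypothetical names for illustration, not the CuTeDSL pipeline API; the sketch assumes the common scheme where the stage index wraps modulo the stage count and a phase bit flips on every wrap.

```python
from dataclasses import dataclass


@dataclass
class PipeState:
    """Hypothetical circular pipeline state: stage index plus a phase bit."""
    num_stages: int
    index: int = 0
    phase: int = 0

    def advance(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0    # wrap back to stage 0
            self.phase ^= 1   # flip the phase so waiters can tell the laps apart


def schedule(num_tiles: int, num_stages: int = 2):
    """Return the (tile, stage, phase) triple used for each KV tile, in order."""
    state = PipeState(num_stages)
    out = []
    for kt in range(num_tiles):
        out.append((kt, state.index, state.phase))
        state.advance()
    return out


# n=512 with 128-wide KV tiles gives 4 tiles over 2 stages.
for row in schedule(4):
    print(row)
```

Under these assumptions tiles 2 and 3 reuse stages 0 and 1 with the phase flipped. If that reuse never happens, the consumer only ever sees the first two tiles, which would match identical output for every n>256.
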
### Remaining for Multi-Tile
1. O rescale between tiles: `O *= exp2(old_max - new_max)` — needed for n=256+ to hit 0.9999
2. Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages) — output identical for all n>256, meaning only 2 KV tiles are loaded
3. Correction warps for production (separate softmax/correction/epilogue)
4. 12-warp layout
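
The rescale in item 1 can be checked on the host with NumPy. The sketch below is a reference oracle under stated assumptions, not the kernel: the function names, the 128-wide KV tile, and the shapes are illustrative, and logits are pre-multiplied by `log2(e)` so that `exp2` reproduces the natural-exponent softmax.

```python
import numpy as np

LOG2E = 1.4426950408889634  # log2(e), so exp(x) == exp2(x * LOG2E)


def attention_reference(q, k, v):
    """Plain softmax(Q K^T) V over the full KV length."""
    s = q @ k.T
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def attention_tiled(q, k, v, tile=128):
    """Online softmax over KV tiles with O rescaled whenever the row max grows."""
    m, d = q.shape[0], v.shape[1]
    row_max = np.full((m, 1), -np.inf)
    row_sum = np.zeros((m, 1))
    o = np.zeros((m, d))
    for start in range(0, k.shape[0], tile):
        s = (q @ k[start:start + tile].T) * LOG2E      # work in base 2
        new_max = np.maximum(row_max, s.max(axis=1, keepdims=True))
        scale = np.exp2(row_max - new_max)             # exp2(old_max - new_max)
        p = np.exp2(s - new_max)
        o = o * scale + p @ v[start:start + tile]      # rescale O, then accumulate PV
        row_sum = row_sum * scale + p.sum(axis=1, keepdims=True)
        row_max = new_max
    return o / row_sum                                 # final normalization


rng = np.random.default_rng(0)
q = rng.standard_normal((128, 64))
k = rng.standard_normal((384, 64))   # 3 KV tiles
v = rng.standard_normal((384, 64))
ref, out = attention_reference(q, k, v), attention_tiled(q, k, v)
cos = float((ref * out).sum() / (np.linalg.norm(ref) * np.linalg.norm(out)))
print(f"cosine = {cos:.6f}")
```

Dropping the `o * scale` term (or the matching `row_sum * scale`) makes earlier tiles carry a stale max, which is the kind of error that only shows up once more than one tile is loaded.
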
### Files
| File | Status | Notes |
|------|--------|-------|
| `fmha_v3_stage_c_example10.py` | 🔨 CURRENT | (None,0,None,0) TMA, combined K+V pipeline, O rescale, final normalize |
| `test_fmha_v3_stage_c_full.py` | OK n=128 | Working real softmax + O normalization |
| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |
### Current Architecture (6-warp)
- Warps 0-3: Softmax + Epilogue
- Warp 4: MMA (QK, PV)
- Warp 5: TMA (Q/K/V load)
### Target Architecture (12-warp, production)
Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
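
For host-side bookkeeping, the 12-warp assignment above can be written as a small role table. This is an illustrative sketch only: `WARP_ROLES`, `role_of`, and `threads_per_cta` are hypothetical names, and the 32 threads per warp is the standard CUDA warp size rather than something this README states.

```python
# Hypothetical role table mirroring the 12-warp target layout listed above.
WARP_ROLES = {
    "softmax": range(0, 4),
    "correction": range(4, 8),
    "mma": range(8, 9),
    "tma": range(9, 10),
    "epilogue": range(10, 11),
    "empty": range(11, 12),
}
THREADS_PER_WARP = 32  # standard CUDA warp size


def role_of(warp_idx: int) -> str:
    """Return the role name for a warp index in the 12-warp layout."""
    for role, warps in WARP_ROLES.items():
        if warp_idx in warps:
            return role
    raise ValueError(f"warp {warp_idx} is outside the 12-warp layout")


def threads_per_cta() -> int:
    """Total threads implied by the role table."""
    return THREADS_PER_WARP * sum(len(w) for w in WARP_ROLES.values())


print([role_of(w) for w in range(12)])
print(threads_per_cta())  # 384
```
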
### CuTeDSL Constraints (hard-won)
1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces the ENTIRE C-fragment to a scalar (global max, not per-row)
3. `cute.arch.fmax`: impure for the vectorizer, so use a plain `range()` loop
4. `tBgK`/`tVgV` have 4 modes after `tma_partition`: `(None,0,None,0)` keeps mode 2 (KV tiles) free, and `[None, kt]` indexes it
5. `tBgK[(None, None, 0, 0)]` hardcodes GMEM iteration to tile 0 (the old pre-slice; use `(None, 0, None, 0)` instead)
6. `softmax_done_bar` NamedBarrier is reusable across tiles
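
Constraint 2 matters because online softmax tracks one max per row. The NumPy sketch below is not CuTeDSL and the values are contrived; it only shows how a single fragment-wide max differs from a per-row max. Mathematically any shared shift still yields a valid softmax, so the failure shown here is the float32 underflow case, and whether the kernel would hit exactly this case is an assumption.

```python
import numpy as np

# Two rows with very different magnitudes, in float32 like the S accumulator.
s = np.array([[0.0, 1.0],
              [200.0, 201.0]], dtype=np.float32)

global_max = s.max()                    # one scalar for the whole fragment
row_max = s.max(axis=1, keepdims=True)  # one value per row, shape (2, 1)


def softmax_with(shift):
    """Row softmax of s after subtracting the given shift."""
    p = np.exp(s - shift)
    with np.errstate(invalid="ignore", divide="ignore"):
        return p / p.sum(axis=1, keepdims=True)


per_row = softmax_with(row_max)       # both rows are valid probabilities
shared = softmax_with(global_max)     # row 0 underflows to 0/0 and becomes NaN
print(per_row)
print(shared)
```

With the per-row max every row keeps at least one `exp(0) = 1` term, so the row sum can never be zero.
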
### Remaining for C' (Production Stage C)
1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
3. Cross-warp reduction for row_max and row_sum
4. Correction warps for multi-tile KV (online O rescale in TMEM)
5. 12-warp layout with separate softmax/correction/epilogue warps
### TMEM Layout
- Col 0-127: S (QK acc, 128 FP32)
- Col 32-95: P (64 FP32), overlapping part of S
- Col 128+: O (PV acc, 64 FP32)

---
## Key Lessons
1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation must be power of 2.**
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
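
Lesson 3 combined with the TMEM Layout section gives a concrete number: S occupies columns 0-127 and O takes 64 columns starting at 128, so 192 columns are in use and the allocation has to round up to 256. The helper below is a hypothetical sketch of that rounding, not a CuTeDSL function; the 32-column minimum and 512-column maximum are assumptions about the hardware allocation limits and should be checked against the toolchain in use.

```python
def tmem_alloc_cols(cols_needed: int, min_cols: int = 32, max_cols: int = 512) -> int:
    """Round a TMEM column count up to the next power of two (hypothetical helper)."""
    if cols_needed < 1:
        raise ValueError("cols_needed must be positive")
    # (n - 1).bit_length() gives the exponent of the next power of two >= n.
    cols = max(min_cols, 1 << (cols_needed - 1).bit_length())
    if cols > max_cols:
        raise ValueError(f"{cols_needed} columns do not fit in {max_cols}")
    return cols


# Layout from this README: S in columns 0-127, O starting at column 128 with 64 columns.
S_COLS, O_COLS = 128, 64
print(tmem_alloc_cols(S_COLS + O_COLS))  # 256
```
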
---
## Environment
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`