biondizzle 8ccbdec1ed 🚀🚀🚀 TMA MULTI-TILE FIX VERIFIED ON B200 🚀🚀🚀
THE BUG: tBgK[(None,None,0,0)] kept modes 0,1 free but set mode 2 (KV tiles) to 0.
TMA always loaded from tile 0 regardless of the coordinate value.
This was a LAYOUT bug, NOT a JIT bug, NOT a CuTeDSL bug.

THE FIX: tBgK[(None,0,None,0)] keeps modes 0 and 2 free.
Then tBgK[None, kt] indexes the surviving KV_tiles dim.

VERIFIED SHAPES (B200, n=256, inside @cute.kernel):
  Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?))  — 4 modes
  After (None,0,None,0): tBgK = (((64,128),1), Int32(?))             — 2 modes

TEST RESULTS (test_fmha_v3_stage_c.py, identity softmax):
  n=128:  cos 0.999998  PASS
  n=256:  cos 0.71    (TMA loads 2 tiles, needs O rescale for 0.9999)
  n=512+: same output as n=256 (pipeline not cycling past kv_stage=2)

example10 (real softmax + O rescale): compiles and runs, cos ~0.47 (softmax bugs separate from TMA)

LESSON: PRINT THE SHAPES. ALWAYS. Reasoning about mode counts without
evidence is how we wasted a day. The 8-mode theory was WRONG — 8-None
slice fails with 'weakly congruent' at JIT compile. The tensor has 4 modes.

Updated: README (verified shapes, correct fix), MEMORY.md (new rules),
test_fmha_v3_stage_c.py, test_fmha_v3_diag.py, example10, test_fmha_v3.py,
fire_b200_test (clean git state, kill all old processes).
2026-05-22 23:51:29 +00:00

# DSV4 Inference Kernel
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes** (verified on B200):
```
tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
mode               0         1     2      3
```
**Mode 2 is the GMEM tile dimension**: the dimension you index with `kt` to load different K/V tiles.
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
```python
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
# TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord indexes wrong mode!
```
### THE RIGHT WAY (verified on B200 at n=128 and n=256):
```python
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
# Mode 2 (KV tiles) survives as the second mode.
tBgK = tBgK[(None, 0, None, 0)]
# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
cute.copy(tma_k, tBgK[None, kt], ...)
# ^^ THIS IS THE KV TILE DIM
```
**Verified shapes on B200 (May 22, n=256, inside @cute.kernel):**
```
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
```
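
The mode bookkeeping can be checked without a GPU. Below is a minimal sketch: a hypothetical pure-Python model of CuTe-style tuple slicing, where `None` keeps a top-level mode and an integer pins it and drops it. It is not the CuTeDSL API. `slice_modes`, `pinned_modes` and the mode labels are illustrative names that follow the shape diagram above.

```python
# Hypothetical model of CuTe-style slicing over a tensor's top-level modes.
# None keeps a mode (it survives in the result); an int pins that mode to a
# single coordinate and removes it. Only the bookkeeping, not the CuTeDSL API.

def slice_modes(modes, slicer):
    """Return the labels of the modes that survive `slicer`."""
    if len(slicer) != len(modes):
        # The model simply rejects length mismatches; real CuTe applies congruence rules.
        raise ValueError(f"slicer has {len(slicer)} entries, tensor has {len(modes)} modes")
    return tuple(m for m, s in zip(modes, slicer) if s is None)

def pinned_modes(modes, slicer):
    """Return {mode label: pinned coordinate} for every integer entry."""
    return {m: s for m, s in zip(modes, slicer) if s is not None}

# Labels for tBgK after tma_partition: 4 top-level modes, mode 2 = KV tiles.
TBGK_MODES = ("tma_box", "mode1", "kv_tiles", "mode3")

wrong = (None, None, 0, 0)
right = (None, 0, None, 0)

print("wrong keeps:", slice_modes(TBGK_MODES, wrong), "pins:", pinned_modes(TBGK_MODES, wrong))
print("right keeps:", slice_modes(TBGK_MODES, right), "pins:", pinned_modes(TBGK_MODES, right))
```

With the wrong slicer, the second index in `tBgK[None, kv_coord]` lands on `mode1` while `kv_tiles` is already fixed at 0, so every copy reads tile 0.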
### WHY THIS IS SO INSIDIOUS
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
3. **Multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
6. **The 8-mode theory was wrong.** We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.
### THE LESSON
**PRINT THE SHAPES. ALWAYS.** Run `print(f"tBgK: shape={cute.shape(tBgK)}")` inside `@cute.kernel` at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.
**The correct pre-slice depends on which mode is the GMEM tile iteration axis.** For our `local_tile` + `partition_B` + `group_modes(0,3)` pattern, mode 2 is the KV tile axis. `(None,0,None,0)` keeps it free. `(None,None,0,0)` collapses it to 0.
```python
# ALWAYS verify the shape at trace time:
print(f"tBgK shape: {cute.shape(tBgK)}") # 4 modes
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}") # 2 modes
# Then index the 2D tensor:
cute.copy(tma_k, tBgK[None, kt], ...)
```
**IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
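
Since the correct pre-slice depends on which mode is the GMEM tile axis, one option is to derive the slicer from that axis rather than hand-writing it. A minimal sketch, assuming a hypothetical helper `make_pre_slice` that keeps mode 0 (the TMA box) plus the chosen tile mode and pins every other mode to 0. The `rank` and `tile_mode` arguments should come from the trace-time shape print, not from reasoning about mode counts.

```python
# Hypothetical helper: build the pre-slice tuple for a TMA-partitioned tensor.
# The TMA box mode and the GMEM tile-iteration mode stay free (None);
# every other top-level mode is pinned to coordinate 0.

def make_pre_slice(rank, tile_mode, box_mode=0):
    if not 0 <= tile_mode < rank:
        raise ValueError(f"tile_mode {tile_mode} out of range for rank {rank}")
    if tile_mode == box_mode:
        raise ValueError("tile_mode must differ from the TMA box mode")
    return tuple(None if i in (box_mode, tile_mode) else 0 for i in range(rank))

# local_tile + partition_B + group_modes(0,3): printed rank 4, KV-tile axis = mode 2.
pre_slice = make_pre_slice(rank=4, tile_mode=2)
print(pre_slice)  # (None, 0, None, 0)
```

Inside the kernel the resulting tuple would, in principle, be used exactly like the literal (`tBgK = tBgK[pre_slice]`, then `tBgK[None, kt]`), since it is an ordinary Python tuple at trace time.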
---
## Architecture
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.
```
DSV4 inference pipeline — component status
==========================================
Legend:
[✓] built and tested
[~] partial — reference or seam exists, native pending
[✗] to build
┌────────────────────────────────────┐
│ [✗] Embedding + mHC init           │
│     token embed + n_hc=4 streams   │
└─────────────────┬──────────────────┘
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after                │
│                                                                      │
│ ┌─ Attention sub-block ────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC       pre + post mix                            │ │
│ │ [~] Norms + RoPE       RMSNorm + partial RoPE                    │ │
│ │ [✓] Q / KV projection  NVFP4 linears + LoRA                      │ │
│ │ [~] Token compressor   CSA m=4 / HCA m=128                       │ │
│ │ [✗] Indexer + top-k    CSA only, FP4 QK                          │ │
│ │ [~] FMHA core          QK → online softmax → PV                  │ │
│ │                        + SWA branch + sink merge                 │ │
│ │ [✓] Output projection  inv RoPE + wo_a grouped + wo_b            │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│                                                                      │
│ ┌─ FFN sub-block ──────────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC       pre + post mix                            │ │
│ │ [~] Pre-FFN norm       RMSNorm                                   │ │
│ │ [✗] Router             sqrt(softplus) + topk + hash              │ │
│ │ [✓] Routed MoE         fused SwiGLU L1 + L2                      │ │
│ │ [✓] Shared expert      NVFP4 single-group GEMM                   │ │
│ └──────────────────────────────────────────────────────────────────┘ │
└───────────────────────────────────┬──────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler    │
└──────────────────────────────────────────────────────────────────────┘
┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                              │
│     • state cache: SWA window + uncompressed tail per layer          │
│     • classical paged cache: lcm(4, 128) = 128 tokens per block      │
│     • heterogeneous layout per layer                                 │
└──────────────────────────────────────────────────────────────────────┘
Summary
-------
Built    [✓] : 6 - mHC ×2, Q/KV proj, output proj, routed MoE,
                   shared expert
Partial  [~] : 4 - norms+RoPE, token compressor, FMHA core,
                   pre-FFN norm
To build [✗] : 8 - embedding+init, indexer+top-k, router,
                   final norm, LM head, MTP, sampler, KV cache
```
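
The 128-token paged-cache block is lcm(4, 128) for CSA m=4 and HCA m=128. Presumably the point is that one block always holds a whole number of compressed entries for both compressors. The back-of-envelope sketch below assumes that each compressor emits one latent entry per m complete tokens and leaves the remainder in the uncompressed tail, which is consistent with the state-cache bullet above. `compressed_entries` is a hypothetical helper, not a function in `dsv4/`. `math.lcm` needs Python 3.9+.

```python
import math

CSA_M = 4    # CSA: one compressed entry per 4 tokens
HCA_M = 128  # HCA: one compressed entry per 128 tokens

def compressed_entries(num_tokens, m):
    """Hypothetical model: full groups of m tokens -> entries, remainder -> uncompressed tail."""
    return num_tokens // m, num_tokens % m

block_tokens = math.lcm(CSA_M, HCA_M)
print("block tokens:", block_tokens)                    # 128
print("CSA entries per block:", block_tokens // CSA_M)  # 32
print("HCA entries per block:", block_tokens // HCA_M)  # 1

for t in (1000, 4096):
    print(f"T={t}: CSA (entries, tail)={compressed_entries(t, CSA_M)}, "
          f"HCA (entries, tail)={compressed_entries(t, HCA_M)}")
```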
---
## Status (May 22, 2026 — 16:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ MULTI-TILE TMA FIXED | n=128 cos 0.999998 ✅. TMA fix: n=256 loads 2 tiles. Pipeline cycling needed for n≥384. O rescale needed. |
| C' | 🔨 IN PROGRESS | Production Stage C: O rescale, pipeline state cycling, correction warps, 12-warp layout. See below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
---
## Package Structure
```
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│ ├── attention/ FMHA kernel (stub — extraction is Stage E)
│ ├── compressor/ CSA/HCA token-level compressor
│ ├── decode/ Decode-time attention (sparse, SWA — future)
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── decode_swa.py native_swa_decode dispatcher
│ ├── rope.py Forward + inverse RoPE
│ └── topk.py Python wrapper for sparse_topk_metadata.cu
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear
│ ├── moe.py Nvfp4MoE
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer
│ └── (stubs: attention, ffn, router, norm, embedding)
├── model/ Model assembly (stubs — Phase 1)
├── cache/ KV cache infra (stubs — Phase 3)
├── loader/ Checkpoint I/O (stubs — Phase 1)
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py RoPE, KV cache, causal attention, SWA
├── csa_attention.py CSA/HCA sparse attention
├── compressor.py Compressor PyTorch example
└── moe_pipeline.py MoE pipeline reference
```
**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.
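
One way to keep the left-to-right rule honest is a small static check. A sketch under assumptions: it treats `kernels`, `ops`, `layers` and `model` as the ordered production layers, flags any import of `reference/` from them (per the "never imported by production code" note), and only inspects absolute `dsv4.*` imports via the standard `ast` module. `LAYER_ORDER` and `find_violations` are hypothetical names.

```python
import ast

# Left-to-right dependency order: a production layer may import itself or any
# layer to its left, never one to its right.
LAYER_ORDER = ["kernels", "ops", "layers", "model"]

def imported_dsv4_layers(source):
    """Yield the dsv4 subpackage named by each absolute dsv4 import in `source`."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            if node.module == "dsv4":  # from dsv4 import ops
                names = ["dsv4." + alias.name for alias in node.names]
            else:                       # from dsv4.ops import quantize
                names = [node.module]
        else:
            continue
        for name in names:
            parts = name.split(".")
            if parts[0] == "dsv4" and len(parts) > 1:
                yield parts[1]

def find_violations(source, own_layer):
    """Return imported layers that break the layering rule for `own_layer`."""
    if own_layer not in LAYER_ORDER:
        return []  # sidecars (reference/, loader/) are not checked here
    bad = []
    for layer in imported_dsv4_layers(source):
        if layer == "reference":
            bad.append(layer)  # reference/ must never be imported by production code
        elif layer in LAYER_ORDER and LAYER_ORDER.index(layer) > LAYER_ORDER.index(own_layer):
            bad.append(layer)
    return bad

print(find_violations("from dsv4.ops import quantize", "layers"))          # []
print(find_violations("import dsv4.layers.moe", "ops"))                    # ['layers']
print(find_violations("from dsv4.reference import attention", "kernels"))  # ['reference']
```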
---
## Active Test Files
### FMHA (Stages A/B/C) — in `tests/unit/`
| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c_full.py` | C | ✅ Real online softmax + O normalization, cosine 0.993-0.996 |
| `test_fmha_v3_stage_c_min.py` | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |
### MoE / GEMM — in `tests/unit/`
| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | Cudagraph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |
### Archived Tests
`tests/archive/` contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.
---
## Test Harness
Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):
### `run_test.sh` — Run a test in a screen session
```bash
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
```
What it does:
1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU procs that ignore SIGHUP)
2. Deletes the old log file
3. Starts a new `screen -dmS kernel-test` running the test
4. Logs output to `/tmp/kernel-test.log`
5. Verifies the screen started
### `check_log.sh` — Check test progress
```bash
bash tests/check_log.sh
```
Shows the log contents and whether the screen is still running.
### Local → B200 workflow
```bash
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push
# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
# 3. Check results
bash tests/check_log.sh
```
### `fire_b200_test` — One-command local test runner
Lives in `~/.openclaw/workspace/fire_b200_test` (NOT in the repo — project-specific tooling).
```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```
What it does:
1. Auto-commits and pushes any local changes
2. SSH to B200, pulls, starts `run_test.sh` in a screen
3. Polls every 15s until the screen exits
4. Dumps the full test log to your terminal
**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.
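
The poll-until-exit step can be sketched in Python. This is a hypothetical stand-in, not the actual `fire_b200_test` script. `is_running` would wrap something like an SSH call that checks the `kernel-test` screen. The `sleep`/`clock` hooks and the `timeout_s` cap are assumptions, added so that the loop is testable and cannot hang forever on a deadlocked kernel.

```python
import time

def wait_for_exit(is_running, interval_s=15.0, timeout_s=3600.0,
                  sleep=time.sleep, clock=time.monotonic):
    """Poll `is_running()` every `interval_s` seconds until it returns False.

    Returns the number of polls made. Raises TimeoutError if the session is
    still running after `timeout_s` seconds (e.g. a deadlocked GPU process).
    """
    start = clock()
    polls = 0
    while True:
        polls += 1
        if not is_running():
            return polls
        if clock() - start >= timeout_s:
            raise TimeoutError(f"test session still running after {timeout_s:.0f}s")
        sleep(interval_s)

# Example with a fake session that finishes on the third check.
states = iter([True, True, False])
print(wait_for_exit(lambda: next(states), sleep=lambda s: None))  # 3
```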
---
## Stage C: Online Softmax — Multi-Tile In Progress
### What We Have
**Working real softmax** for single KV tile (n=128): cosine 0.999998.
**Multi-tile TMA indexing fixed**: n=256 now loads both KV tiles (cos 0.71 until O rescale lands). It was a layout bug, NOT a JIT bug.
**Remaining:** O rescale between tiles, pipeline state cycling for n≥384, correction warps.
### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: `(((64,128),1), ?, KV_tiles, ?)`.
**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` kept modes 0,1 free and set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
**The fix:** `(None,0,None,0)` keeps modes 0,2 free, then `[None, kt]` indexes KV tiles:
```python
tBgK = tBgK[(None, 0, None, 0)]
cute.copy(tma_k, tBgK[None, kt], ...)
```
**Results after TMA fix (verified on B200, May 22):**
- n=128: cos 0.999998 ✅
- n=256: cos 0.71 (TMA loads 2 tiles correctly, needs O rescale for 0.9999)
- n=512/1024: output identical to n=256 — pipeline not cycling past kv_stage=2
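
Outputs that stop changing beyond n=256 are what a missing stage/phase wrap would look like. The sketch below is a pure-Python model of a multi-stage pipeline state in the style of CUTLASS's `PipelineState`: a stage index plus a phase bit that flips each time the index wraps. `ToyPipelineState` and `schedule` are hypothetical names, and the real producer/consumer barriers are not modeled.

```python
from dataclasses import dataclass

@dataclass
class ToyPipelineState:
    """Stage index + phase bit for a ring of `stages` SMEM buffers."""
    stages: int
    index: int = 0
    phase: int = 0

    def advance(self):
        self.index += 1
        if self.index == self.stages:  # wrapped around the ring
            self.index = 0
            self.phase ^= 1            # waiters must now expect the other parity

def schedule(num_kv_tiles, kv_stage=2):
    """Return (kv_tile, smem_stage, phase) for each KV tile load."""
    state = ToyPipelineState(stages=kv_stage)
    out = []
    for kt in range(num_kv_tiles):
        out.append((kt, state.index, state.phase))
        state.advance()
    return out

for row in schedule(4):  # n=512 with 128-token tiles -> 4 KV tiles
    print(row)
# (0, 0, 0)
# (1, 1, 0)
# (2, 0, 1)
# (3, 1, 1)
```

Tiles 2 and 3 reuse stages 0 and 1 with the flipped phase. One plausible reading of the symptom, not a diagnosis, is that some producer or consumer never advances its state past the second stage, or waits on the stale phase. Either way, only the first `kv_stage` tiles ever reach the MMA.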
**Verified tensor shapes (diag prints inside @cute.kernel on B200, n=256):**
```
Before (None,0,None,0) pre-slice:
tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
tVgV: (((64,128),1), 1, 1, 1) — 4 modes
After (None,0,None,0) pre-slice:
tAgQ: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
tBgK: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
tVgV: (((64,128),1), 1) — 2 modes, mode 1 = 1 (static)
```
### Remaining for Multi-Tile
1. O rescale between tiles: `O *= exp2(old_max - new_max)` — needed for n=256+ to hit 0.9999
2. Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages) — output identical for all n>256, meaning only 2 KV tiles are loaded
3. Correction warps for production (separate softmax/correction/epilogue)
4. 12-warp layout
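
Item 1 is the standard online-softmax correction. Below is a scalar reference sketch for one query row. It assumes that scores are converted to base 2 with log2(e), so the rescale factor is `exp2(old_max - new_max)` as written above, and that KV tiles are processed in order. It is plain Python for checking numbers, not the TMEM implementation. `online_softmax_row` and `reference_row` are illustrative names.

```python
import math

LOG2E = math.log2(math.e)

def exp2(x):
    # 2.0 ** -inf == 0.0 in Python, so the first-tile rescale zeroes empty accumulators
    return 2.0 ** x

def online_softmax_row(score_tiles, value_tiles):
    """softmax(scores) @ values for one query row, consuming one KV tile at a time."""
    row_max = -math.inf  # running max, in log2 units
    row_sum = 0.0        # running sum of exp2(s - row_max)
    o = 0.0              # running unnormalized output (scalar head_dim for brevity)
    for scores, values in zip(score_tiles, value_tiles):
        s2 = [s * LOG2E for s in scores]
        new_max = max(row_max, max(s2))
        scale = exp2(row_max - new_max)  # O *= exp2(old_max - new_max)
        o *= scale
        row_sum *= scale
        for s, v in zip(s2, values):
            p = exp2(s - new_max)
            row_sum += p
            o += p * v
        row_max = new_max
    return o / row_sum                   # final normalization by the row sum

def reference_row(scores, values):
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)

scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]
values = [1.0, -2.0, 0.5, 4.0, 3.0, -1.0]
one_tile = online_softmax_row([scores], [values])
two_tiles = online_softmax_row([scores[:3], scores[3:]], [values[:3], values[3:]])
print(abs(two_tiles - reference_row(scores, values)) < 1e-9)  # True
```

In this example the second tile raises the row max (2.0 to 3.0), so removing the two `*= scale` lines makes the two-tile result disagree with the reference, while the one-tile result is unaffected.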
### Files
| File | Status | Notes |
|------|--------|-------|
| `fmha_v3_stage_c_example10.py` | 🔨 CURRENT | (None,0,None,0) TMA, combined K+V pipeline, O rescale, final normalize |
| `test_fmha_v3_stage_c_full.py` | OK n=128 | Working real softmax + O normalization |
| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |
### Current Architecture (6-warp)
- Warps 0-3: Softmax + Epilogue
- Warp 4: MMA (QK, PV)
- Warp 5: TMA (Q/K/V load)
### Target Architecture (12-warp, production)
- Warps 0-3: Softmax
- Warps 4-7: Correction
- Warp 8: MMA
- Warp 9: TMA
- Warp 10: Epilogue
- Warp 11: Empty
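
Both layouts can be written as warp-index-to-role tables, which makes the 6-warp to 12-warp move easy to audit. A sketch with hypothetical names (`ROLES_6W`, `ROLES_12W`, `warps_for`). In the kernel the warp index would come from the thread index, which is not modeled here.

```python
# Warp roles transcribed from the 6-warp and 12-warp descriptions above.
ROLES_6W = ["softmax+epilogue"] * 4 + ["mma", "tma"]
ROLES_12W = ["softmax"] * 4 + ["correction"] * 4 + ["mma", "tma", "epilogue", "empty"]

WARP_SIZE = 32

def warps_for(role, roles):
    """Warp indices assigned to `role`."""
    return [w for w, r in enumerate(roles) if r == role]

def threads_per_cta(roles):
    return len(roles) * WARP_SIZE

print(threads_per_cta(ROLES_6W), threads_per_cta(ROLES_12W))  # 192 384
print(warps_for("correction", ROLES_12W))                     # [4, 5, 6, 7]
print(ROLES_12W[9])                                           # tma
```

Keeping one table as the single source for the per-warp dispatch would mean the layout change touches one list instead of several scattered index checks.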
### CuTeDSL Constraints (hard-won)
1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
4. `tBgK`/`tVgV` have 4 modes after tma_partition — (None,0,None,0) keeps mode 2 (KV tiles) free, [None, kt] indexes it
5. `tBgK[(None, None, 0, 0)]` pins mode 2 to 0 and silently hardcodes GMEM iteration to tile 0
6. `softmax_done_bar` NamedBarrier is reusable across tiles
### Remaining for C' (Production Stage C)
1. Multi-tile TMA: indexing fixed by the `(None,0,None,0)` pre-slice; pipeline state cycling for n≥384 still open
2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
3. Cross-warp reduction for row_max and row_sum
4. Correction warps for multi-tile KV (online O rescale in TMEM)
5. 12-warp layout with separate softmax/correction/epilogue warps
### TMEM Layout
| Columns | Contents |
|---------|----------|
| 0-127 | S (QK accumulator, 128 FP32) |
| 32-95 | P (64 FP32, overlaps S) |
| 128+ | O (PV accumulator, 64 FP32) |
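
The column map can be sanity-checked in a few lines. The checks are that P sits inside S (the two-alias arrangement from Key Lessons), that O does not overlap S, and that the allocation rounds up to a power of two. A sketch under assumptions: O is taken to occupy 64 columns starting at 128, and the 32-column minimum and 512-column maximum follow the PTX `tcgen05.alloc` rules as we understand them. `overlaps` and `tmem_alloc_cols` are hypothetical helpers.

```python
def overlaps(a, b):
    """True if half-open column ranges [start, end) intersect."""
    return a[0] < b[1] and b[0] < a[1]

def tmem_alloc_cols(used_cols, min_cols=32, max_cols=512):
    """Round a column count up to the next power of two within [min_cols, max_cols]."""
    cols = max(min_cols, 1 << (used_cols - 1).bit_length())
    if cols > max_cols:
        raise ValueError(f"{used_cols} columns exceed TMEM capacity of {max_cols}")
    return cols

S = (0, 128)    # QK accumulator, 128 FP32 columns
P = (32, 96)    # P aliases S columns 32-95
O = (128, 192)  # PV accumulator, 64 FP32 columns (assumed width)

assert S[0] <= P[0] and P[1] <= S[1], "P should alias inside S"
assert not overlaps(S, O), "O must not clobber S/P"
print(tmem_alloc_cols(max(S[1], P[1], O[1])))  # 256
```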
---
## Key Lessons
1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation must be power of 2.**
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
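
Lesson 4 in miniature: a stride bug that uses the row count where the column count belongs is invisible whenever rows == cols. This is a plain-Python illustration with a hypothetical buggy indexer, not code from the kernel.

```python
def flat_index_correct(i, j, rows, cols):
    return i * cols + j  # row-major: the stride of dim 0 is the column count

def flat_index_buggy(i, j, rows, cols):
    return i * rows + j  # bug: uses the wrong extent as the row stride

def gather(buf, rows, cols, index_fn):
    return [[buf[index_fn(i, j, rows, cols)] for j in range(cols)] for i in range(rows)]

def same_result(rows, cols):
    buf = list(range(rows * cols))
    return gather(buf, rows, cols, flat_index_correct) == gather(buf, rows, cols, flat_index_buggy)

print(same_result(4, 4))  # True: the square shape hides the bug
print(same_result(2, 4))  # False: a non-square shape exposes it
```

On a tall shape (rows > cols), the buggy indexer would run off the end of the buffer. In Python that raises an IndexError. On the GPU it would be the kind of silent or illegal access that square tests never trigger.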
---
## Environment
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`