Keep epilogue_tma_store for final output (proven path). Only fix the multi-KV-tile O rescale using paired atoms from epilogue_tmem_copy_and_partition. The paired atoms share addressing, making the TMEM->REGS->modify->TMEM cycle lossless. Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128) is completely unaffected — zero regression risk. Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred until we can address the MLIR compilation time issue.
ROADMAP
Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md — this file is for "what now."
Last updated: 2026-05-26
Current status
Working
| Component | hd | n | cos | Status |
|---|---|---|---|---|
| FMHA TMEM-P | 64 | 128 | 0.999998 | ✅ |
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | ✅ |
| FMHA TMEM-P | 256 | 128 | 0.999998 | ✅ |
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | ✅ Workaround |
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D5c sink merge single-tile | 64 | 128 | 0.999996 | ✅ |
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | ✅ |
| Per-head multi-head launch | 64 | 128 | 0.999995 | ✅ n_h=1–128 |
| MoE fused SwiGLU (NVFP4) | — | — | matches ref | ✅ Clamping in kernel |
| Dense router (sqrt-softplus) | — | — | matches ref | ✅ |
| Hash router | — | — | matches ref | ✅ |
| `use_2cta_instrs` conditional | — | — | 1.7–1.9× speedup | ✅ M≥256 prefill |
| NVFP4 primitives | — | — | E4M3 SF / mxf4nvf4 / 16-elem | ✅ Verified |
Known blockers
| Blocker | Impact | Workaround | Fix path |
|---|---|---|---|
| D1.5 TMEM round-trip corruption | Hand-built atoms produce 3% error on NO-OP round-trip; blocks in-kernel multi-tile O rescale and in-kernel normalize | Emit un-normalized O + LSE; Python KV merge for s_k>128 | Priority 1: correction-epilog rewrite (sketch below) |
| hd=512 MLIR backend hang | Cannot compile hd=512 kernel (>3hr optimizer time, structurally correct) | Run hd=512 via head-packed M with hd≤256 chunks; ship without hd=512 if needed for D2 | Pre-compile cubin / raw CUTLASS C++ / report NVIDIA bug |
| D2 multi-CTA grid (flat_divide + epilogue_tma_store) | Per-head Python launch wastes 128 launches per decode step at Pro | Per-head launch (works, just slow) | Unblocked by correction-epilog rewrite (uses flat_divide + tma_partition like MoE does) |
Priority 1: Correction epilog rewrite (unblocks D1.5 + a chain of follow-ons)
Why this first: Every downstream item needs the kernel to have a register-level slot in the epilogue for modification. The current epilogue_tma_store path with hand-built atoms doesn't have one. The correction-epilog pattern does.
The pattern is already in the codebase — dsv4/kernels/gemm/fused_swiglu.py uses it for the MoE SwiGLU epilogue (lines 2021, 2064–2229). Library helpers, paired atoms, one-way TMEM → registers → SMEM → GMEM. SwiGLU + clamping math sits between the t2r and r2s copies. That's the exact slot FMHA needs.
What changes:
Replace the FMHA epilogue (dsv4/kernels/attention/fmha.py lines 549–597 — the epilogue_tma_store call) with:
- Setup (run once per kernel, outside the kt loop):
  - `tCtO_transformed = utils.gemm.sm100.transform_partitioned_tensor_layout(tOtO0)`
  - `tCgC_transformed = utils.gemm.sm100.transform_partitioned_tensor_layout(tCgC)`
  - `tiled_copy_t2r, tTR_tO_base, tTR_rO = utils.gemm.sm100.epilogue_tmem_copy_and_partition(...)`
  - `tiled_copy_r2s, tRS_rC, tRS_sC = utils.gemm.sm100.epilogue_smem_copy_and_partition(...)`
  - TMA partition for C via `flat_divide(tCgC_transformed, epi_tile)` + `cpasync.tma_partition(...)`
- Final epilogue (replaces the round-trip normalize):
  - Subtile loop, in each subtile: `cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)` → multiply by `inv_row_sum` in registers if `normalize=True` → cast to BF16 → `cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[...])` → TMA SMEM → GMEM.
- Per-kt O rescale (replaces the broken hand-built round-trip on lines 524–544):
  - Inside the kt loop, when `kt > 0`: same t2r → multiply by `acc_scale` in registers → store back via the paired atom (`tiled_copy_t2r.retile_to_S()` or equivalent; verify the exact API on B200).
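The per-kt rescale and the final normalize above follow the usual online-softmax accumulator update. The sketch below is a plain-Python reference for one query row, not kernel code. It assumes `acc_scale` is the standard `exp(old_row_max - new_row_max)` factor and that `inv_row_sum` is the reciprocal of the running softmax denominator; the function names are hypothetical.

```python
import math

def attention_row_reference(scores, values):
    """Plain softmax(scores) @ values for a single query row."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def attention_row_tiled(scores, values, tile=2):
    """Same result, one KV tile at a time, rescaling O when the row max grows."""
    dim = len(values[0])
    row_max = -math.inf
    row_sum = 0.0
    acc = [0.0] * dim  # un-normalized O accumulator (the TMEM-resident O)
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(row_max, max(s_tile))
        # On the first tile row_max is -inf, so acc_scale is 0.0; that is
        # harmless because acc and row_sum are still zero at that point.
        acc_scale = math.exp(row_max - new_max)
        acc = [a * acc_scale for a in acc]   # the per-kt O rescale
        row_sum *= acc_scale
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - new_max)
            row_sum += p
            acc = [a + p * x for a, x in zip(acc, v)]
        row_max = new_max
    inv_row_sum = 1.0 / row_sum              # the final-epilogue normalize
    return [a * inv_row_sum for a in acc]
```

A reference like this gives the multi-tile test an oracle that is independent of the tile size, which is the property the in-kernel rescale has to preserve.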
What unblocks:
- D1.5 issue 1 (round-trip corruption): gone.
- D1.5 issue 2 (per-kt rescale): gone.
- In-kernel multi-tile attention (single launch for s_k=1152, not 9).
- NVFP4-1.2 (fuse FP4 quant into FMHA output → wo_a path): the register slot is where amax + FP4 pack go.
- D2 multi-CTA grid: `flat_divide` + `tma_partition` path is the same one MoE uses successfully. The `flat_divide` vs `local_tile` mismatch resolves.
Caveats to print and verify on B200:
- Exact CUTLASS helper API for the "store back to TMEM" direction (`retile_to_S` form vs separate helper vs same-base-tensor pattern).
- Whether `transform_partitioned_tensor_layout` accepts `tOtO0` (TMEM iterator with offset) or needs a fresh tensor built at `tmem_ptr + self.tmem_o0_offset` with `tCtO_fake.layout`.
- Whether `tma_partition` inside `if warp_idx < self.mma_warp_id` works in this kernel's region tree. The MoE kernel does it; if FMHA hits "weakly congruent," hoist the partition call above the warp branch.
Done when:
- hd=64/128/256 regression cos ≥ 0.999998 holds with `normalize=True` and `normalize=False`.
- New multi-tile s_k=256 test with `kt > 0` rescale gives cos ≥ 0.999998 (not the current 0.997 Python-merge workaround, the real in-kernel rescale).
- Existing Python KV merge tests continue to pass (`test_d15_multi_kv.py`).
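The acceptance thresholds above can be checked with a small helper. This is a minimal sketch that assumes "cos" means cosine similarity between the flattened kernel output and the reference output; `cosine_similarity` and `passes_regression` are hypothetical names, not functions from the test suite.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two equal-length flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def passes_regression(kernel_out, reference_out, threshold=0.999998):
    """True when the output meets the roadmap's cosine threshold."""
    return cosine_similarity(kernel_out, reference_out) >= threshold
```

Cosine similarity ignores a uniform scale error, so a missing or doubled normalize would still pass; pairing it with a max-absolute-error check may be worth considering.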
Priority 2: Stage E — Production extraction
D5 is complete. The kernel works. Wrap it in a proper interface.
| Step | What | Status |
|---|---|---|
| E1 | File placement: `dsv4/kernels/attention/fmha.py` | ✅ Done |
| E2 | Constructor signature (head_dim, num_query_heads, sliding_window, top_k, sink/causal flags, dtypes) | ⚠️ Partial, needs cleanup |
| E3 | Call signature: `q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o, stream` | ⚠️ Needs sink_bias / row_sums integration |
| E4 | Kernel cache + warmup, keyed on `(head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...)` | TODO |
| E5 | `torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))` | TODO |
| E6 | Reference parity test against FP32 oracle in `dsv4/reference/attention.py` | TODO |
| E7 | Cleanup: delete debug test files, keep only `tests/unit/test_fmha_kernel.py` | TODO |
Notably absent from the call signature: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream by the indexer + gather kernel chain. FMHA sees a dense BF16 [T, top_k, head_dim] tile.
Priority 3: NVFP4-1.1 — Fuse FP4 quant into MoE SwiGLU epilogue
Independent of FMHA. Biggest bandwidth win in the codebase. Can run in parallel with Priority 1.
Current:
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
↓
quantize_activation_nvfp4 (separate kernel)
↓
padded_activated_fp4 → L2 GEMM
Target:
padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM
The SwiGLU + clamp result already lives in registers at tRS_rC.store(acc_vec_bf16) (line 2207 of fused_swiglu.py). That's the slot for amax + FP4 pack.
Per-microblock amax (16 contiguous elements):
- shfl_xor butterfly reduction across the 4 threads that hold the 16 elements.
- FP8 E4M3 scale = amax / 6 (FP4 e2m1 max).
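- Per-element FP4 pack: sign bit << 3 | 3-bit e2m1 magnitude code, where the code selects the representable magnitude in {0, 0.5, 1, 1.5, 2, 3, 4, 6} nearest to the clamped |val| / scale (a code lookup, not a linear cast to uint3). Two elements → one byte.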
- 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
- FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.
Subtlety: NVFP4 microblock = 16 elements. Port the same 16-element logic from dsv4/ops/quantize.py. Don't accidentally use the 32-element MXFP4 block.
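The per-microblock steps above can be prototyped on the host before they move into the epilogue. The following is a minimal pure-Python sketch, not the code in `dsv4/ops/quantize.py`. It makes three simplifying assumptions: the scale stays a Python float instead of being rounded to FP8 E4M3, ties go to the smaller magnitude instead of round-to-nearest-even, and the first element of each pair lands in the low nibble. `quantize_microblock` is a hypothetical name.

```python
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # magnitude codes 0..7
FP4_MAX = 6.0
MICROBLOCK = 16  # NVFP4 block size (MXFP4 would be 32)

def quantize_microblock(values):
    """Quantize 16 floats to (scale, 8 packed bytes of e2m1 nibbles)."""
    assert len(values) == MICROBLOCK
    amax = max(abs(v) for v in values)
    # Assumption: an all-zero block gets scale 1.0 to avoid dividing by zero.
    scale = amax / FP4_MAX if amax > 0.0 else 1.0
    nibbles = []
    for v in values:
        magnitude = min(abs(v) / scale, FP4_MAX)  # clamp to the e2m1 range
        # Nearest representable magnitude; min() keeps the first (smaller)
        # code on an exact tie, which differs from hardware tie-breaking.
        code = min(range(8), key=lambda c: abs(E2M1_MAGNITUDES[c] - magnitude))
        sign = 1 if v < 0.0 else 0
        nibbles.append((sign << 3) | code)
    # Two nibbles per byte, first element in the low nibble (assumed order).
    packed = bytes(nibbles[i] | (nibbles[i + 1] << 4)
                   for i in range(0, MICROBLOCK, 2))
    return scale, packed
```

The 8 bytes returned per block are the 64-bit word mentioned above; in the kernel the `max` over the block would be the `shfl_xor` butterfly across the 4 threads holding the elements.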
Done when:
- `padded_activated_fp4` and `padded_activated_x_sf` scratch buffers go away.
- `quantize_activation_nvfp4` between L1 and L2 disappears.
- L1 → L2 cosine matches reference (no regression from BF16 intermediate).
- L2 GEMM reads FP4 scales produced by L1 epilogue.
Priority 4: D2 multi-CTA grid
Currently per-head Python launch (works, cos 0.999995, but 128 launches per decode step at Pro).
Multi-CTA grid is unblocked by Priority 1 — the flat_divide + tma_partition path becomes available once the epilogue uses the MoE pattern.
Grid: (num_M_tiles, num_query_heads, batch) — at decode T=1: (1, 128, batch).
MQA K/V sharing: start with independent K/V loads per CTA (each CTA loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB in total, a small transfer at HBM bandwidth. Cluster-wide sharing via cluster_shape_mn=(1, num_query_heads, 1) is a future optimization once profiling shows it matters.
Q tensor layout: Option 1 — (batch, n_h, T, head_dim) with head as a TMA mode (matches CUTLASS reference and allows per-head LSE output). Picked over Option 2 (heads packed into M) because it generalizes better to GQA later.
Done when:
- `n_h=128, batch=4, T=1` at hd=512 produces correct output with single launch.
- Per-head LSE writes to correct `mLSE[batch, head, m_row]` position.
Priority 5: NVFP4-1.2 — Fuse FP4 quant into FMHA output → wo_a path
Depends on Priority 1 (correction epilog gives the register slot).
Currently: FMHA emits BF16 → inverse RoPE produces BF16 → wo_a quantizes to FP4.
Target: register slot in FMHA epilogue does the divide-by-row_sum and inverse RoPE rotation and per-microblock amax + FP4 pack. wo_a reads FP4 directly.
Same pattern as Priority 3. Different home (FMHA epilogue, not MoE epilogue).
Priority 6: NVFP4-2 — FP4 KV pipeline depth in FMHA
Depends on Priority 1 being solid at BF16 KV first.
FP4 KV shrinks tiles ~4×; same SMEM budget supports more pipeline stages.
| KV dtype | Tile size (hd=512) | Stages that fit (192 KB budget) |
|---|---|---|
| BF16 | 128 KB | 2 |
| FP8 | 64 KB | 4 |
| FP4 | ~36 KB | 6 |
At 1M-context decode where KV reads dominate, deeper pipelines hide more TMA latency.
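The tile sizes in the table can be reproduced with a few lines of arithmetic. This sketch assumes a 128-row KV tile at hd=512 and, for FP4, one 1-byte scale factor per 16-element microblock; both are inferred from the table rather than stated in it, and the helper names are hypothetical.

```python
KIB = 1024

def kv_tile_bytes(rows, head_dim, bits_per_element,
                  scale_block=None, scale_bytes=1):
    """Bytes for one KV tile, plus optional per-microblock scale factors."""
    data = rows * head_dim * bits_per_element // 8
    scales = 0
    if scale_block is not None:
        scales = rows * (head_dim // scale_block) * scale_bytes
    return data + scales

def stages_that_fit(tile_bytes, budget_bytes):
    """Whole pipeline stages that fit if SMEM holds nothing but KV tiles."""
    return budget_bytes // tile_bytes

bf16_tile = kv_tile_bytes(128, 512, 16)                # 128 KiB
fp8_tile = kv_tile_bytes(128, 512, 8)                  # 64 KiB
fp4_tile = kv_tile_bytes(128, 512, 4, scale_block=16)  # 32 KiB data + 4 KiB scales
```

Under these assumptions, plain floor division against a 192 KB budget gives 1, 3, and 5 stages, one fewer than each table entry. The table's counts presumably rely on an accounting this sketch does not model, so they are worth re-deriving against the real SMEM layout on B200.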
Implementation:
- TMA loads FP4 NoPE dims (packed `e2m1_x2`) to SMEM slot 0.
- TMA loads BF16 RoPE dims to SMEM slot 1.
- TMA loads FP8 scale factors to SMEM slot 2.
- SMEM dequant FP4 → BF16 in vectorized form (`* FP8_scale`, 16-element microblocks).
- Concatenate `[NoPE, RoPE]` in SMEM.
- MMA reads contiguous BF16 from SMEM.
Test: FP4+BF16 split input → identical output to pure BF16 input (dequant must be transparent).
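A host-side decoder is a convenient oracle for the transparency test. The sketch below unpacks one microblock and applies its scale. It assumes the low nibble holds the first element of each pair and treats the scale as an ordinary float; `dequantize_microblock` and `concat_nope_rope` are hypothetical names used only for illustration.

```python
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # magnitude codes 0..7

def dequantize_microblock(packed, scale):
    """Unpack 8 bytes (16 e2m1 nibbles, low nibble first) and apply the scale."""
    values = []
    for byte in packed:
        for nibble in (byte & 0x0F, byte >> 4):
            magnitude = E2M1_MAGNITUDES[nibble & 0x7]
            sign = -1.0 if nibble & 0x8 else 1.0
            values.append(sign * magnitude * scale)
    return values

def concat_nope_rope(nope_values, rope_values):
    """Row layout the MMA is expected to read: dequantized NoPE, then RoPE."""
    return list(nope_values) + list(rope_values)
```

Feeding the decoded NoPE values plus the untouched RoPE values to the pure-BF16 path gives the expected output for the split-input path.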
Priority 7: hd=512 fix
Blocked. Per Priority 4, multi-CTA grid + head-packed M means decode at hd=512 can route through pv_n_tile=128 and n_k_sub_tiles=2, which compiles fine for hd=256. The hd=512 single-kernel compile is the missing piece for prefill efficiency, not correctness.
Options:
- Pre-compile hd=512 cubin offline (accept 1–2 hour compile if MLIR ever finishes — uncertain).
- Add no-softmax mode emitting raw S to GMEM, call twice for k_sub=0/1, accumulate in Python, softmax once. Two launches but no MLIR hang.
- Write hd=512 path in raw CUTLASS C++. Bypasses CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
- Report CuTeDSL MLIR optimizer bug to NVIDIA.
Lower priority than the chain above — at decode T=1, n_h=128, hd=512 the head-packed approach already works without needing a single hd=512 kernel.
Priority 8: Indexer FP4 tensor-core scoring (Stage F)
Paper §5.2.1: "the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."
Current indexer (dsv4/kernels/cuda/indexer_score_topk.cu): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. Single largest perf gap in the codebase. At 1M-context decode the indexer scores ~250K compressed entries per query token — the spinlock heap will not scale to top_k=1024.
Target: port DeepGEMM fp8_paged_mqa_logits to FP4 inputs with tcgen05.mma.kind=mxf4nvf4. Plus per-warp partial top-k merged with a final reduction tree (or radix-select). Plus FP32→BF16 score quantization per paper (2× speedup on top-k selector, 99.7% recall).
Scope: 2–3 weeks. Track for Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 5 are established — they'll inform the indexer's FP4 load + score paths.
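The "per-warp partial top-k merged with a final reduction tree" idea can be modeled on the host to pin down expected results before any kernel work. This is a minimal sketch using the standard-library `heapq`; the partitions stand in for warps, and `partial_topk` and `merge_topk` are hypothetical names.

```python
import heapq

def partial_topk(scores, k, num_partitions):
    """Each partition keeps its own top-k as (score, index) pairs."""
    if not scores:
        return []
    chunk = -(-len(scores) // num_partitions)  # ceil division
    partials = []
    for start in range(0, len(scores), chunk):
        indexed = [(s, start + i)
                   for i, s in enumerate(scores[start:start + chunk])]
        partials.append(heapq.nlargest(k, indexed))
    return partials

def merge_topk(partials, k):
    """Pairwise reduction tree over the partial top-k lists."""
    level = partials
    while len(level) > 1:
        merged = []
        for i in range(0, len(level), 2):
            right = level[i + 1] if i + 1 < len(level) else []
            merged.append(heapq.nlargest(k, level[i] + right))
        level = merged
    return level[0] if level else []
```

Because every partition keeps a full top-k, the merged result equals the global top-k, and no shared heap or lock is involved at any step.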
Build order — recommended sequencing
Now ─┬─ Priority 1 (correction epilog rewrite)
│ │
│ └─→ unblocks D1.5, D2 multi-CTA, NVFP4-1.2
│
├─ Priority 3 (NVFP4-1.1 fuse FP4 in SwiGLU) ← parallel, independent
│
↓
Verify hd=64/128/256 regressions hold
│
↓
Priority 2 (Stage E production extraction)
│
↓
Priority 4 (D2 multi-CTA grid)
│
↓
Priority 5 (NVFP4-1.2 fuse FP4 in FMHA output)
│
↓
Priority 6 (NVFP4-2 FP4 KV pipeline)
│
↓
Priority 7 (hd=512 fix — only if prefill efficiency demands it)
│
↓
Priority 8 (indexer FP4 tensor-core scoring) — Stage F
Priority 3 has no dependency on Priorities 1 or 2 and can run on a parallel branch.
Speculative — beyond what the V4 paper validated
Listed for completeness. Do not implement without explicit sign-off.
- NVFP4 compressed KV NOPE dims (paper validated FP8 for compressed KV; FP4 would halve cache again). Risk: compounds quantization noise on already-lossy compressed KV.
- MXFP4 vs NVFP4 for indexer scoring — not validated for indexer specifically.
- NVFP4 for full attention Q×K^T GEMM — closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
- Per-token FP8 activation scaling in FMHA — not validated. Out of scope.
- 2:4 structured sparsity on FP4 expert weights — V4 not trained with structured sparsity. Off the table for the released checkpoint.
- NVFP4 LM head + MTP head — big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against held-out eval before shipping.
Key numbers to remember
| Config | n_h | top_k | s_k decode | n_kv_tiles | Multi-tile? |
|---|---|---|---|---|---|
| Flash decode | 64 | 512 | 640 | 5 | YES |
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
| Current single-tile test | 1 | — | 128 | 1 | NO |
Production decode needs the multi-tile path (Priority 1) working in-kernel. Today's Python KV merge ships correct results at the cost of 5–9 launches per step.
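The tile counts in the table follow from a ceiling division over 128-entry KV tiles. The sketch below also assumes that decode `s_k` is `top_k` plus a 128-entry SWA window, which matches both production rows (512 + 128 = 640, 1024 + 128 = 1152) but is an inference from the table; the function names are hypothetical.

```python
def decode_s_k(top_k, swa_window=128):
    """KV length at decode, assuming top_k compressed entries plus the SWA window."""
    return top_k + swa_window

def n_kv_tiles(s_k, kv_tile=128):
    """Number of KV tiles the kernel must iterate over (ceiling division)."""
    return -(-s_k // kv_tile)

def needs_multi_tile(s_k, kv_tile=128):
    """True when the in-kernel O rescale path (kt > 0) is exercised."""
    return n_kv_tiles(s_k, kv_tile) > 1
```

With the Python KV merge workaround, `n_kv_tiles` is also the launch count per step, which is where the 5 and 9 launches come from.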