ROADMAP
Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md — this file is for "what now."
Last updated: 2026-05-26 (revised after correcting D1.5 fix-path analysis)
Current status
Working
| Component | hd | n | cos | Status |
|---|---|---|---|---|
| FMHA TMEM-P | 64 | 128 | 0.999998 | ✅ |
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | ✅ |
| FMHA TMEM-P | 256 | 128 | 0.999998 | ✅ |
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | ✅ Workaround — see below |
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D5c sink merge single-tile | 64 | 128 | 0.999996 | ✅ |
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | ✅ |
| Per-head multi-head launch | 64 | 128 | 0.999995 | ✅ n_h=1–128 |
| MoE fused SwiGLU (NVFP4) | — | — | matches ref | ✅ Clamping in kernel |
| Dense router (sqrt-softplus) | — | — | matches ref | ✅ |
| Hash router | — | — | matches ref | ✅ |
| `use_2cta_instrs` conditional | — | — | 1.7–1.9× speedup | ✅ M≥256 prefill |
| NVFP4 primitives | — | — | E4M3 SF / mxf4nvf4 / 16-elem | ✅ Verified |
Known blockers
| Blocker | Impact | Status |
|---|---|---|
| Per-kt O rescale in TMEM (D1.5) | Multi-tile attention requires 5–9 kernel launches per decode step instead of 1 (Python KV merge) | Workaround is correct (cos 0.999998). Whether to fix depends on profiling — see Priority 1. |
| TMEM final-normalize round-trip | Cannot do in-kernel `O /= row_sum` cleanly | Already worked around: emit un-normalized O + LSE, external divide is exact. Not a blocker for shipping. |
| `epilogue_tma_store` blocks D2 multi-CTA + NVFP4-1.2 | Per-head Python launch wastes 128 launches per Pro decode step; FMHA output forces BF16 GMEM round-trip before wo_a | Unblocked by Priority 2 (the one-way final-epilogue rewrite, which is the only part of the MoE pattern that legitimately ports to FMHA). |
| hd=512 MLIR backend hang | Cannot compile single-kernel hd=512 (>3hr optimizer time, structurally correct) | Decode works via head-packed M with hd≤256 chunks. Single-kernel hd=512 only needed for prefill efficiency. Low priority. |
What the MoE epilogue pattern actually buys us (and what it doesn't)
This deserves stating explicitly because the previous version of this document had it wrong.
The MoE pattern is one-way only. dsv4/kernels/gemm/fused_swiglu.py uses epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition to construct a TMEM → registers → SMEM → GMEM pipeline. There is no corresponding store-back-to-TMEM helper, and no inverse pairing of the t2r atom. The MoE epilogue runs once, after all MMA K-tiles are accumulated. It never needs to mutate the TMEM accumulator mid-loop.
FMHA's per-kt O rescale is structurally different. PV uses tcgen05.mma with ACCUMULATE=True across the kt loop. The accumulator must live in TMEM because that's where MMA reads it and writes it. When row_max changes between kt iterations, the running O accumulator has to be multiplied by acc_scale in TMEM before the next PV — load to registers, multiply, store back. The store-back is the part that's broken: Ld32x32bOp and St32x32bOp built as separate atoms have hardware column mappings that don't match, producing ~3% corruption even on NO-OP round-trip. No software layout transformation in CuTeDSL Python has, so far, made them pair correctly.
Therefore: porting the MoE pattern to FMHA only fixes the one-way paths:
- The final epilogue (after the last kt, when O is being written to GMEM for good).
- Any FP4 amax + pack fusion into that final epilogue.
It does not fix the per-kt rescale. That is a separate, harder problem with three possible paths laid out below.
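For reference, the arithmetic the per-kt rescale has to perform can be sketched on the host. This is a minimal scalar model (one query row, head_dim of 1), not the kernel: the function names are hypothetical, and the running `o` stands in for the TMEM accumulator that the kernel would have to load, multiply by `acc_scale`, and store back.

```python
import math

def attend_full(scores, values):
    """Reference: softmax over all scores at once, then weighted sum of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

def attend_tiled(scores, values, tile):
    """Online softmax over KV tiles with a per-tile rescale of the running O."""
    row_max = float("-inf")
    row_sum = 0.0
    o = 0.0  # un-normalized running accumulator (TMEM-resident in the kernel)
    for kt in range(0, len(scores), tile):
        s_tile = scores[kt:kt + tile]
        v_tile = values[kt:kt + tile]
        new_max = max(row_max, max(s_tile))
        # exp(-inf) == 0.0, so the first tile starts from an empty accumulator.
        acc_scale = math.exp(row_max - new_max)
        p = [math.exp(s - new_max) for s in s_tile]
        # This multiply is the step that needs the TMEM load/store round-trip.
        o = o * acc_scale + sum(pi * vi for pi, vi in zip(p, v_tile))
        row_sum = row_sum * acc_scale + sum(p)
        row_max = new_max
    return o / row_sum
```

When `row_max` does not change between tiles, `acc_scale` is exactly 1.0 and the rescale is a no-op. With a single KV tile there is nothing to rescale at all.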
Priority 1: Profile production decode to determine if the rescale fix is needed
Before investing days in fixing the per-kt rescale, measure whether the 5–9 launch overhead from Python KV merge is actually a bottleneck.
At Pro decode (s_k=1152, n_kv_tiles=9), Python merge dispatches 9 kernels per step. Conservative launch overhead ~50 μs per kernel ≈ 450 μs/step in launch overhead alone. If a full decode step (all 61 layers, MoE, embedding, sampler) takes ~30 ms, that's ~1.5% of latency. If it takes ~10 ms, it's ~4.5%. Whether that's worth a 1–2 week refactor depends on the actual measurement.
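The estimate above is simple enough to keep as a small helper, so the same formula can be re-run once measured numbers replace the guesses. The function name and parameters are hypothetical. The inputs shown are the assumed values from the paragraph, not measurements.

```python
def merge_overhead_fraction(n_kv_tiles, launch_us, step_ms):
    """Fraction of one decode step spent on FMHA launch overhead alone."""
    overhead_us = n_kv_tiles * launch_us
    return overhead_us / (step_ms * 1000.0)

# Pro decode: 9 launches at an assumed ~50 us each.
for step_ms in (30.0, 10.0):
    frac = merge_overhead_fraction(9, 50.0, step_ms)
    print(f"step={step_ms:.0f} ms -> launch overhead {frac:.1%}")
```

Both assumed cases fall below the 5% threshold used in the decision rule, which is why the measured step latency matters more than the launch count.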
Action:
- Profile Pro decode at s_k=1152 with current Python KV merge. Measure: total step latency, launch overhead from FMHA dispatches, FMHA compute time per launch.
- Measure CPU dispatch overhead on the host (Python loop + kernel cache lookup).
- Decision rule: if Python merge overhead is < 5% of total decode latency, defer Priority 8 indefinitely. Ship Python merge as production path.
Done when: there's a profiled number that justifies (or doesn't) the engineering investment in Priority 8.
Priority 2: One-way final-epilogue rewrite
What: Replace the utils.gemm.sm100.epilogue_tma_store(...) call at the end of FMHA (dsv4/kernels/attention/fmha.py lines 565–577) with the MoE-style explicit pipeline:
transform_partitioned_tensor_layout → epilogue_tmem_copy_and_partition →
[register slot — optional normalize/cast/FP4 pack] →
epilogue_smem_copy_and_partition → flat_divide → cpasync.tma_partition → TMA store
This is strictly one-way. The kt-loop rescale code stays exactly as it is (using the broken hand-built atoms — see Priority 8 for the fix path).
What this enables:
- Optional in-kernel normalize. Adds an `if const_expr(self.normalize):` block at the register slot to multiply by `inv_row_sum`. Currently external code does the divide on the un-normalized output. Tiny perf win, not the main reason to do this.
- Unblocks NVFP4-1.2 (Priority 6): gives a register-level modification slot in the FMHA output path where FP4 amax + pack can live, eliminating the BF16 GMEM materialization between FMHA and wo_a.
- Likely unblocks D2 multi-CTA grid (Priority 4): the current `epilogue_tma_store` is what couldn't accept the `flat_divide`-based GMEM coordinate system. Switching to the explicit `cpasync.tma_partition(tma_c, ..., cute.flat_divide(tCgC_transformed, epi_tile))` path puts FMHA on the same TMA pattern MoE uses successfully, which should accept multi-CTA block coordinates.
Caveats to verify on B200 before assuming this works:
- Whether `tma_partition` survives inside the `if warp_idx < self.mma_warp_id` block. MoE calls it from inside its epilogue warp's `if`, but that has not been tested in FMHA's region tree. Previous attempts at the full pattern triggered 20+ minute MLIR compile times before reaching a verdict.
- Whether `transform_partitioned_tensor_layout` accepts FMHA's `tOtO0` (TMEM iterator with offset) directly, or whether it needs a fresh tensor built at `tmem_ptr + self.tmem_o0_offset` with `tCtO_fake.layout`.
- The `epilogue_tmem_copy_and_partition` helper signature on the current CuTeDSL version: print on B200 before coding.
Failure mode to watch for: if compile hangs as it did previously, this rewrite is genuinely blocked and the follow-ons that depend on it (Priorities 4 and 6) need alternative paths.
Effort: 1–2 days if the helpers cooperate. Multiple days if the MLIR hang reappears.
Done when:
- hd=64/128/256 regression cos ≥ 0.999998 holds with both `normalize=True` and `normalize=False` paths.
- LSE output still matches reference for `normalize=False` callers.
- Compile time is reasonable (< 5 min). If not, document the hang and fall back.
Priority 3: NVFP4-1.1 — Fuse FP4 quant into MoE SwiGLU epilogue
Independent of FMHA. Can run in parallel with Priority 2. Biggest bandwidth win in the codebase.
Current:
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
↓
quantize_activation_nvfp4 (separate kernel)
↓
padded_activated_fp4 → L2 GEMM
Target:
padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM
The SwiGLU + clamp result already lives in registers at tRS_rC.store(acc_vec_bf16) (line 2207 of fused_swiglu.py). That's the slot for amax + FP4 pack.
Per-microblock amax (16 contiguous elements):
- `shfl_xor` butterfly reduction across the 4 threads holding the 16 elements.
- FP8 E4M3 scale = amax / 6 (FP4 e2m1 max).
- Per-element FP4 pack: `sign_bit << 3 | (clamped_val / scale).to(uint3)`. Two elements → one byte.
- 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
- FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.
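The steps above can be modeled on the host as a reference for checking the fused epilogue. This is a sketch under stated assumptions, not the kernel code: the function names are hypothetical, the scale is kept as a Python float rather than rounded to E4M3, ties round to the lower code rather than to nearest-even, and the low nibble is assumed to hold the first element of each pair.

```python
# Magnitudes representable by the 3-bit e2m1 code (codes 0..7); max is 6.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_microblock(vals):
    """Quantize 16 contiguous values to 8 packed bytes plus one scale."""
    assert len(vals) == 16
    amax = max(abs(v) for v in vals)          # per-microblock amax
    scale = amax / 6.0 if amax > 0.0 else 1.0
    nibbles = []
    for v in vals:
        mag = min(abs(v) / scale, 6.0)        # clamp to the e2m1 range
        code = min(range(8), key=lambda i: abs(E2M1_GRID[i] - mag))
        sign = 1 if v < 0.0 else 0
        nibbles.append((sign << 3) | code)
    # Two elements per byte (assumed order: first element in the low nibble).
    packed = bytes(nibbles[i] | (nibbles[i + 1] << 4) for i in range(0, 16, 2))
    return scale, packed

def dequantize_microblock(scale, packed):
    """Inverse of quantize_microblock, used to check the round-trip."""
    out = []
    for b in packed:
        for nib in (b & 0xF, b >> 4):
            mag = E2M1_GRID[nib & 0x7]
            out.append((-mag if nib & 0x8 else mag) * scale)
    return out
```

The 8 packed bytes per microblock are the 64-bit word mentioned in the list. A kernel-side test can compare its packed output and scale against this model before comparing L1→L2 cosine.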
Done when: padded_activated_fp4 and padded_activated_x_sf scratch buffers go away, quantize_activation_nvfp4 between L1 and L2 disappears, L1→L2 cosine matches reference.
Priority 4: D2 multi-CTA grid
Depends on Priority 2. Currently per-head Python launch dispatches 128 kernels per Pro decode step. Multi-CTA grid collapses that to 1.
Grid: (num_M_tiles, num_query_heads, batch) — at decode T=1: (1, 128, batch).
Q tensor layout: Option 1 — (batch, n_h, T, head_dim) with head as a TMA mode. Matches CUTLASS reference, allows per-head LSE output, generalizes to GQA later.
MQA K/V sharing: start with independent K/V loads per CTA (each loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB total, comfortably within HBM bandwidth. Cluster-wide sharing via cluster_shape_mn=(1, num_query_heads, 1) is a future optimization once profiling shows it matters.
Done when: n_h=128, batch=4, T=1 at hd=512 produces correct output with single launch, per-head LSE writes to mLSE[batch, head, m_row] correctly.
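The grid shape and the redundant K/V footprint quoted above can be checked with a few lines. This is an illustrative sketch only: the helper names are hypothetical, and the M tile of 128 rows is an assumption rather than something the text states.

```python
def decode_grid(num_query_heads, batch, t=1, m_tile=128):
    """Grid as (num_M_tiles, num_query_heads, batch); m_tile is assumed."""
    num_m_tiles = -(-t // m_tile)  # ceiling division
    return (num_m_tiles, num_query_heads, batch)

def redundant_kv_bytes(num_ctas, kv_bytes_per_cta):
    """Total K/V bytes loaded when every CTA fetches its own copy."""
    return num_ctas * kv_bytes_per_cta

grid = decode_grid(num_query_heads=128, batch=4)
total = redundant_kv_bytes(num_ctas=128, kv_bytes_per_cta=128 * 1024)
print(grid, total // (1024 * 1024), "MiB")
```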
Priority 5: Stage E — Production extraction
D5 is complete. Wrap the kernel in a proper interface.
| Step | What | Status |
|---|---|---|
| E1 | File placement: `dsv4/kernels/attention/fmha.py` | ✅ Done |
| E2 | Constructor signature `(head_dim, num_query_heads, sliding_window, top_k, sink/causal flags, dtypes)` | ⚠️ Partial: needs cleanup |
| E3 | Call signature: `q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o, stream` | ⚠️ Needs sink_bias / row_sums integration |
| E4 | Kernel cache + warmup, keyed on `(head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...)` | TODO |
| E5 | `torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))` | TODO |
| E6 | Reference parity test against FP32 oracle in `dsv4/reference/attention.py` | TODO |
| E7 | Cleanup: delete debug test files, keep only `tests/unit/test_fmha_kernel.py` | TODO |
Block table, paged KV, FP8 dequant, inv_scale — all handled upstream by the indexer + gather chain. FMHA sees a dense BF16 [T, top_k, head_dim] tile.
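One possible shape for the E4 kernel cache is a frozen dataclass key plus a dictionary of compiled kernels. This is a sketch, not the planned implementation: `FmhaKey`, `KernelCache`, and `compile_fn` are hypothetical names, and the key lists only the fields named in the table (the `...` in E4 is left open).

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable

@dataclass(frozen=True)
class FmhaKey:
    """Hashable cache key; fields taken from the E4 row, others may be needed."""
    head_dim: int
    num_query_heads: int
    top_k: int
    n_comp: int
    apply_sink_bias: bool
    is_causal: bool

class KernelCache:
    def __init__(self, compile_fn: Callable[[FmhaKey], Any]):
        self._compile_fn = compile_fn      # stand-in for the real compile step
        self._cache: Dict[FmhaKey, Any] = {}

    def get(self, key: FmhaKey) -> Any:
        """Return the compiled kernel, compiling at most once per key."""
        kernel = self._cache.get(key)
        if kernel is None:
            kernel = self._compile_fn(key)
            self._cache[key] = kernel
        return kernel

    def warmup(self, keys: Iterable[FmhaKey]) -> None:
        """Compile every expected configuration ahead of the first request."""
        for key in keys:
            self.get(key)
```

Warmup matters here because of the compile times reported elsewhere in this document: paying them on the first live request is not acceptable.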
Priority 6: NVFP4-1.2 — Fuse FP4 quant into FMHA output → wo_a path
Depends on Priority 2 (uses the register slot in the new final epilogue).
Currently: FMHA emits BF16 → inverse RoPE → BF16 GMEM → wo_a quantizes to FP4.
Target: register slot in FMHA's new final epilogue does O / row_sum and inverse RoPE rotation and per-microblock amax and FP4 pack. wo_a reads FP4 directly with no GMEM materialization.
Same pattern as Priority 3, different home (FMHA final epilogue, not MoE epilogue).
Priority 7: NVFP4-2 — FP4 KV pipeline depth in FMHA
Depends on Priority 2 being solid at BF16 KV first. FP4 KV shrinks tiles ~4×; same SMEM budget supports more pipeline stages.
| KV dtype | Tile size (hd=512) | Stages fitting 192 KB |
|---|---|---|
| BF16 | 128 KB | 2 |
| FP8 | 64 KB | 4 |
| FP4 | ~36 KB | 6 |
At 1M-context decode where KV reads dominate, deeper pipelines hide more TMA latency.
Implementation:
- TMA loads FP4 NoPE dims (`e2m1_x2` packed) to SMEM slot 0.
- TMA loads BF16 RoPE dims to SMEM slot 1.
- TMA loads FP8 scale factors to SMEM slot 2.
- SMEM dequant FP4 → BF16 vectorized (`* FP8_scale`, 16-element microblocks).
- Concatenate `[NoPE, RoPE]` in SMEM.
- MMA reads contiguous BF16 from SMEM.
Test: FP4+BF16 split input → identical output to pure BF16 input. Dequant must be transparent.
Priority 8 (conditional on P1 profile): Per-kt O rescale fix
Only justified if Priority 1 profiling shows Python KV merge overhead > 5% of decode latency. Otherwise defer indefinitely — the current correct workaround ships.
There are three possible paths. None is a small change.
Path A: CUTLASS atom replication
The CUTLASS C++ Blackwell FMHA does a TMEM round-trip for O rescale using SM100_TMEM_LOAD_32dp32b_4x_atom and SM100_TMEM_STORE_32dp32b_4x_atom, which are paired by hardware design. So in principle TMEM round-trip is possible. The question is whether CuTeDSL Python exposes the specific atom variants and layout configuration CUTLASS uses, and whether they can be paired correctly through make_tmem_copy.
Steps:
- Read `/root/cutlass/.../blackwell/kernel/attention/fmha/fmha.py` (or equivalent C++ reference) and document the exact atom + repetition + tensor layout used for `correction_rescale`.
- Enumerate what CuTeDSL Python exposes: `dir(tcgen05.copy)`, available `LdNxNbOp`/`StNxNbOp` variants, what `Repetition(N)` controls.
- Identify the difference between current `Ld32x32bOp(Repetition(16))` + `St32x32bOp(Repetition(16))` and whatever CUTLASS uses.
- Build a minimal NO-OP round-trip test (load O, store back unchanged) with the candidate atom configuration. Verify cos = 1.0.
- If NO-OP passes, retest with `* acc_scale` modification.
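The pass criterion for the round-trip test is a cosine similarity of 1.0 between the accumulator before and after the load/store. Below is a host-side sketch of that check. The helper name is hypothetical, and the column swap is only a stand-in for a mismatched load/store mapping, not a model of the actual hardware behavior.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Reference accumulator contents (arbitrary but deterministic).
o_before = [math.sin(0.37 * i) + 0.01 * i for i in range(128)]

# A correct NO-OP round-trip returns the same values.
o_noop = list(o_before)

# A mismatched column mapping returns the same values in the wrong places.
o_swapped = list(o_before)
o_swapped[3], o_swapped[19] = o_swapped[19], o_swapped[3]

print(cosine_similarity(o_before, o_noop))
print(cosine_similarity(o_before, o_swapped))
```

A permutation leaves the norm unchanged, so any drop below 1.0 comes entirely from misplaced elements. For a NO-OP test an exact element-wise comparison is stricter than cosine and is worth running alongside it.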
Risk: CuTeDSL Python may not expose the necessary atom variants, or may not allow the layout configuration CUTLASS uses. In that case, escalate to Path B or C.
Effort if it works: 2–4 days investigation + 1–2 days porting.
Path B: O accumulator in registers, manual PV
Restructure FMHA so the PV accumulator is register-resident, not TMEM-resident. Each kt: read V from SMEM, read P from TMEM/SMEM, compute one-shot PV (no accumulate) writing to a temporary TMEM region, then load to registers and add to register-resident running O (with acc_scale applied to the running O before the add).
Implications:
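- Register pressure is severe at hd=512. 512 FP32 per row × 1 row per thread × 128 threads = 65,536 FP32 values = 256 KB of registers just for O. That equals the entire 64K × 32-bit register file of one SM and exceeds the 255-register per-thread limit, so O would have to be spread over more threads or processed in head-dim sub-tiles.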
- PV without TMEM accumulate is a non-standard MMA usage. May need to use a smaller PV tile and accumulate in registers across sub-tiles.
- Loses the natural MMA-accumulator pipeline overlap.
Effort: 1–2 weeks. High risk of regressing hd=64/128/256 paths during the refactor.
Path C: O accumulator in SMEM
Variant of Path B with O in SMEM instead of registers. PV writes to a temporary TMEM region, gets loaded to registers, applies acc_scale * existing_SMEM_O + new_PV, stores back to SMEM. Final epilogue reads from SMEM.
Implications:
- SMEM budget tightens significantly (need O in SMEM = ~64 KB at hd=512).
- Adds SMEM read/write pressure on every kt.
- May require dropping kv_stage to 1 across the board.
Effort: ~1 week. Lower risk than Path B but bigger perf impact (more SMEM traffic).
Recommended order if profiling demands a fix
- Try Path A first — least invasive, may just work with the right atom config.
- If Path A confirmed impossible in CuTeDSL Python, try Path C (SMEM-resident O) before Path B (register-resident O). SMEM gives more headroom at hd=512.
- Path B only if both fail and the perf gap is truly critical.
Priority 9: hd=512 single-kernel fix
Currently blocked. MLIR optimizer hangs > 3 hours on the hd=512 kernel. Tracer completes in 0.8s — kernel is structurally correct.
Decode works via head-packed M with pv_n_tile=128 and n_k_sub_tiles=2, so this is only a prefill efficiency issue. Lower priority than the chain above.
Options:
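- Pre-compile cubin offline (accept a multi-hour compile, cache result). This only helps if the optimizer eventually terminates, which the > 3 hour runs so far have not confirmed.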
- Add no-softmax mode emitting raw S to GMEM; call twice for k_sub=0/1, accumulate in Python, softmax once externally.
- Write hd=512 path in raw CUTLASS C++. Bypasses CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
- Report CuTeDSL MLIR optimizer bug to NVIDIA.
Priority 10: Indexer FP4 tensor-core scoring (Stage F)
Paper §5.2.1: "the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."
Current indexer (dsv4/kernels/cuda/indexer_score_topk.cu): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. Single largest perf gap in the codebase. At 1M-context decode it scores ~250K compressed entries per query token — the spinlock heap will not scale to top_k=1024.
Target: port DeepGEMM fp8_paged_mqa_logits to FP4 inputs with tcgen05.mma.kind=mxf4nvf4. Plus per-warp partial top-k merged with a final reduction tree (or radix-select). Plus FP32→BF16 score quantization (paper claims 2× speedup on top-k selector at 99.7% recall).
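The "per-warp partial top-k merged with a final reduction" idea can be illustrated on the host. This is a sketch of the selection logic only, with hypothetical function names: each chunk stands in for the entries one warp would score, and `heapq.nlargest` stands in for whatever selection primitive the kernel ends up using.

```python
import heapq

def partial_topk(scores, start, k):
    """Top-k (score, global_index) pairs within one warp's chunk."""
    return heapq.nlargest(k, ((s, start + i) for i, s in enumerate(scores)))

def merged_topk(scores, k, num_warps):
    """Per-chunk top-k followed by a single merge, with no shared heap or lock."""
    chunk = -(-len(scores) // num_warps)  # ceiling division
    partials = [
        partial_topk(scores[w * chunk:(w + 1) * chunk], w * chunk, k)
        for w in range(num_warps)
    ]
    # Every global top-k element is in some chunk's top-k, so merging the
    # partial lists and selecting again gives the exact global result.
    return heapq.nlargest(k, (item for part in partials for item in part))
```

Each warp keeps at most k candidates, so the merge step sees `num_warps * k` entries instead of the full score list, and no warp contends on shared state while scoring.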
Scope: 2–3 weeks. Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 6 are established — they inform the indexer's FP4 load + score paths.
Build order — recommended sequencing
Priority 1 (PROFILE) ──► gates Priority 8
│
Priority 2 (one-way ─┼─► unblocks Priority 4 (multi-CTA)
final epilogue) │ unblocks Priority 6 (FP4 fuse in FMHA)
│
Priority 3 (NVFP4-1.1) ─┴── parallel, independent
[verify hd regressions]
│
▼
Priority 4 (D2 multi-CTA grid)
│
▼
Priority 5 (Stage E production extraction)
│
▼
Priority 6 (NVFP4-1.2 FP4 fuse in FMHA output)
│
▼
Priority 7 (NVFP4-2 FP4 KV pipeline)
│
▼
Priority 8 (per-kt rescale fix — ONLY if P1 says it matters)
│
▼
Priority 9 (hd=512 fix — only if prefill efficiency demands)
│
▼
Priority 10 (indexer FP4 tensor-core scoring) — Stage F
Key change from the previous version: Priority 1 is now a profiling task, not an engineering task. Priority 2 is scoped honestly — it's the one-way path only, and it does not fix the per-kt rescale. The per-kt rescale is Priority 8, conditional, and three paths deep because there is no easy fix.
Speculative — beyond what the V4 paper validated
Listed for completeness. Do not implement without explicit sign-off.
- NVFP4 compressed KV NoPE dims. Paper validated FP8; FP4 would halve cache again. Risk: compounds quantization noise on already-lossy compressed KV.
- MXFP4 vs NVFP4 for indexer scoring. Not validated for indexer specifically.
- NVFP4 for full attention Q×K^T GEMM. Closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
- Per-token FP8 activation scaling in FMHA. Not validated. Out of scope.
- 2:4 structured sparsity on FP4 expert weights. V4 not trained with structured sparsity. Off the table for the released checkpoint.
- NVFP4 LM head + MTP head. Big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against held-out eval before shipping.
Key numbers to remember
| Config | n_h | top_k | s_k decode | n_kv_tiles | Multi-tile? |
|---|---|---|---|---|---|
| Flash decode | 64 | 512 | 640 | 5 | YES |
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
| Current single-tile test | 1 | — | 128 | 1 | NO |
Production decode needs the multi-tile path. Today's Python KV merge ships correct results at the cost of 5–9 launches per step. Whether that cost matters is what Priority 1 measures.
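For reference, the merge that the Python KV path performs across launches can be sketched as follows. This is a scalar model with hypothetical names, assuming each per-tile launch yields a normalized partial output plus its LSE (that is, after the external divide) and assuming a KV tile of 128 rows, which is what the `n_kv_tiles` column implies for s_k = 640 and 1152.

```python
import math

def n_kv_tiles(s_k, tile=128):
    """Number of per-tile launches; tile=128 is inferred from the table."""
    return -(-s_k // tile)

def partial_attention(scores, values):
    """One launch: normalized output and log-sum-exp for a single KV tile."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    o = sum(wi * vi for wi, vi in zip(w, values)) / total
    return o, m + math.log(total)

def merge_partials(partials):
    """Combine (o_i, lse_i) pairs with weights exp(lse_i - lse_total)."""
    lse_max = max(lse for _, lse in partials)
    weights = [math.exp(lse - lse_max) for _, lse in partials]
    denom = sum(weights)
    o = sum(w * o_i for w, (o_i, _) in zip(weights, partials)) / denom
    return o, lse_max + math.log(denom)
```

The merged result matches attention over the full KV range up to floating-point rounding, which is why the workaround is correct. Its cost is one launch per tile.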