# ROADMAP

Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md; this file is for "what now."

**Last updated:** 2026-05-26 (revised after correcting the D1.5 fix-path analysis)

---

## Current status

### Working

| Component | hd | n | cos / result | Status |
|---|---:|---:|---:|---|
| FMHA TMEM-P | 64 | 128 | 0.999998 | ✅ |
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | ✅ |
| FMHA TMEM-P | 256 | 128 | 0.999998 | ✅ |
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | ✅ Workaround (see below) |
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D5c sink merge single-tile | 64 | 128 | 0.999996 | ✅ |
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | ✅ |
| Per-head multi-head launch | 64 | 128 | 0.999995 | ✅ n_h=1–128 |
| MoE fused SwiGLU (NVFP4) | n/a | n/a | matches ref | ✅ Clamping in kernel |
| Dense router (sqrt-softplus) | n/a | n/a | matches ref | ✅ |
| Hash router | n/a | n/a | matches ref | ✅ |
| `use_2cta_instrs` conditional | n/a | n/a | 1.7–1.9× speedup | ✅ M≥256 prefill |
| NVFP4 primitives | n/a | n/a | E4M3 SF / mxf4nvf4 / 16-elem | ✅ Verified |

### Known blockers

| Blocker | Impact | Status |
|---|---|---|
| **Per-kt O rescale in TMEM (D1.5)** | Multi-tile attention requires 5–9 kernel launches per decode step instead of 1 (Python KV merge) | Workaround is correct (cos 0.999998). Whether to fix it depends on profiling; see Priority 1. |
| **TMEM final-normalize round-trip** | Cannot do in-kernel `O /= row_sum` cleanly | **Already worked around**: emit un-normalized O + LSE; the external divide is exact. Not a blocker for shipping. |
| **`epilogue_tma_store` blocks D2 multi-CTA + NVFP4-1.2** | Per-head Python launch wastes 128 launches per Pro decode step; FMHA output forces a BF16 GMEM round-trip before wo_a | Unblocked by Priority 2 (the one-way final-epilogue rewrite, which is the **only** part of the MoE pattern that legitimately ports to FMHA). |
| **hd=512 MLIR backend hang** | Cannot compile single-kernel hd=512 (> 3 h of MLIR optimizer time; the kernel itself is structurally correct) | Decode works via head-packed M with hd≤256 chunks. Single-kernel hd=512 is only needed for prefill efficiency. Low priority. |

---

## What the MoE epilogue pattern actually buys us (and what it doesn't)

This deserves stating explicitly because the previous version of this document had it wrong.

**The MoE pattern is one-way only.** `dsv4/kernels/gemm/fused_swiglu.py` uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` to construct a TMEM → registers → SMEM → GMEM pipeline. There is **no** corresponding store-back-to-TMEM helper, and no inverse pairing of the t2r atom. The MoE epilogue runs *once*, *after* all MMA K-tiles are accumulated. It never needs to mutate the TMEM accumulator mid-loop.

**FMHA's per-kt O rescale is structurally different.** PV uses `tcgen05.mma` with `ACCUMULATE=True` across the kt loop. The accumulator must live in TMEM because that is where the MMA reads and writes it. When row_max changes between kt iterations, the running O accumulator has to be multiplied by `acc_scale` (`exp(old_row_max - new_row_max)` in the usual online-softmax formulation) *in TMEM* before the next PV: load to registers, multiply, store back. The store-back is the broken part: `Ld32x32bOp` and `St32x32bOp` built as separate atoms have hardware column mappings that don't match, producing ~3% corruption even on a NO-OP round-trip. So far, no software layout transformation in CuTeDSL Python has made them pair correctly.

**Therefore:** porting the MoE pattern to FMHA **only fixes the one-way paths**:

1. The final epilogue (after the last kt, when O is written to GMEM for good).
2. Any FP4 amax + pack fusion into that final epilogue.

It does **not** fix the per-kt rescale. That is a separate, harder problem with three possible paths, laid out in Priority 8.

---

## Priority 1: Profile production decode to determine if the rescale fix is needed

Before investing days in fixing the per-kt rescale, measure whether the overhead of 5–9 launches from the Python KV merge is actually a bottleneck.

At Pro decode (s_k=1152, n_kv_tiles=9), the Python merge dispatches 9 kernels per step. At a conservative ~50 μs of launch overhead per kernel, that is ≈ 450 μs per step in launch overhead alone. If a full decode step (all 61 layers, MoE, embedding, sampler) takes ~30 ms, that is ~1.5% of latency; if it takes ~10 ms, it is ~4.5%. Whether that justifies a 1–2 week refactor depends on the actual measurement.

**Action:**

- [ ] Profile Pro decode at s_k=1152 with the current Python KV merge. Measure total step latency, launch overhead from FMHA dispatches, and FMHA compute time per launch.
- [ ] Measure CPU dispatch overhead on the host (Python loop + kernel cache lookup).
- [ ] Decision rule: if Python merge overhead is < 5% of total decode latency, defer Priority 8 indefinitely and ship the Python merge as the production path.

**Done when:** there's a profiled number that justifies (or doesn't) the engineering investment in Priority 8.

---

## Priority 2: One-way final-epilogue rewrite

**What:** Replace the `utils.gemm.sm100.epilogue_tma_store(...)` call at the end of FMHA (`dsv4/kernels/attention/fmha.py` lines 565–577) with the MoE-style explicit pipeline:

```
transform_partitioned_tensor_layout
  → epilogue_tmem_copy_and_partition
  → [register slot: optional normalize / cast / FP4 pack]
  → epilogue_smem_copy_and_partition
  → flat_divide
  → cpasync.tma_partition
  → TMA store
```

This is **strictly one-way**. The kt-loop rescale code stays exactly as it is (using the broken hand-built atoms; see Priority 8 for the fix path).

**What this enables:**

1. **Optional in-kernel normalize.** Adds an `if const_expr(self.normalize):` block at the register slot that multiplies by `inv_row_sum`. Today, external code does the divide on the un-normalized output. Tiny perf win; not the main reason to do this.
2. **Unblocks NVFP4-1.2** (Priority 6). Gives a register-level modification slot in the FMHA output path where FP4 amax + pack can live, eliminating the BF16 GMEM materialization between FMHA and wo_a.
3. **Likely unblocks the D2 multi-CTA grid** (Priority 4). The current `epilogue_tma_store` is what couldn't accept the `flat_divide`-based GMEM coordinate system. Switching to the explicit `cpasync.tma_partition(tma_c, ..., cute.flat_divide(tCgC_transformed, epi_tile))` path puts FMHA on the same TMA pattern MoE uses successfully, which should accept multi-CTA block coordinates.

**Caveats to verify on B200 before assuming this works:**

- Whether `tma_partition` survives inside the `if warp_idx < self.mma_warp_id` block. MoE calls it from inside its epilogue warp's `if`, but that has not been tested in FMHA's region tree. Previous attempts at the full pattern triggered 20+ minute MLIR compile times before reaching a verdict.
- Whether `transform_partitioned_tensor_layout` accepts FMHA's `tOtO0` (TMEM iterator with offset) directly, or needs a fresh tensor built at `tmem_ptr + self.tmem_o0_offset` with `tCtO_fake.layout`.
- The `epilogue_tmem_copy_and_partition` helper signature on the current CuTeDSL version; print it on B200 before coding.

**Failure mode to watch for:** if compilation hangs as it did previously, this rewrite is genuinely blocked and the follow-on chain (Priorities 4 and 6) needs alternative paths.

**Effort:** 1–2 days if the helpers cooperate; multiple days if the MLIR hang reappears.

**Done when:**

- hd=64/128/256 regression cos ≥ 0.999998 holds with both `normalize=True` and `normalize=False` paths.
- LSE output still matches reference for `normalize=False` callers.
- Compile time is reasonable (< 5 min). If not, document the hang and fall back.

---

## Priority 3: Fuse FP4 quant into MoE SwiGLU epilogue (NVFP4-1.1)

**Independent of FMHA; can run in parallel with Priority 2.** Biggest bandwidth win in the codebase.

Current:

```
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
                                      ↓
                     quantize_activation_nvfp4 (separate kernel)
                                      ↓
                     padded_activated_fp4 → L2 GEMM
```

Target:

```
padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM
```

The SwiGLU + clamp result already lives in registers at `tRS_rC.store(acc_vec_bf16)` (line 2207 of `fused_swiglu.py`). That's the slot for amax + FP4 pack.

**Per-microblock amax (16 contiguous elements):**

1. `shfl_xor` butterfly reduction across the 4 threads holding the 16 elements.
2. FP8 E4M3 scale = amax / 6 (6.0 is the FP4 e2m1 max).
3. Per-element FP4 pack: round `|val| / scale` to the nearest e2m1 magnitude (0, 0.5, 1, 1.5, 2, 3, 4, 6 → 3-bit codes 0–7), then OR in `sign_bit << 3`. A plain integer cast such as `(clamped_val / scale).to(uint3)` does not produce e2m1 codes, because the e2m1 grid is non-uniform. Two elements → one byte.
4. 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
5. FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.

**Done when:** the `padded_activated_fp4` and `padded_activated_x_sf` scratch buffers go away, `quantize_activation_nvfp4` between L1 and L2 disappears, and L1→L2 cosine matches reference.

---

## Priority 4: D2 multi-CTA grid

**Depends on Priority 2.** Currently the per-head Python launch dispatches 128 kernels per Pro decode step. A multi-CTA grid collapses that to 1.

**Grid:** `(num_M_tiles, num_query_heads, batch)`, i.e. `(1, 128, batch)` at decode T=1.

**Q tensor layout (Option 1):** `(batch, n_h, T, head_dim)` with head as a TMA mode. Matches the CUTLASS reference, allows per-head LSE output, and generalizes to GQA later.

**MQA K/V sharing:** start with independent K/V loads per CTA (each loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB total, comfortably within HBM bandwidth.
Cluster-wide sharing via `cluster_shape_mn=(1, num_query_heads, 1)` is a future optimization once profiling shows it matters.

**Done when:** `n_h=128, batch=4, T=1` at hd=512 produces correct output with a single launch, and per-head LSE writes to `mLSE[batch, head, m_row]` correctly.

---

## Priority 5: Production extraction (Stage E)

D5 is complete. Wrap the kernel in a proper interface.

| Step | What | Status |
|---|---|---|
| E1 | File placement: `dsv4/kernels/attention/fmha.py` | ✅ Done |
| E2 | Constructor signature (`head_dim`, `num_query_heads`, `sliding_window`, `top_k`, sink/causal flags, dtypes) | ⚠️ Partial; needs cleanup |
| E3 | Call signature: `q`, `compressed_kv`, `swa_kv`, `swa_lens`, `sink_logits`, `request_ids`, `o`, `stream` | ⚠️ Needs sink_bias / row_sums integration |
| E4 | Kernel cache + warmup, keyed on `(head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...)` | TODO |
| E5 | `torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))` | TODO |
| E6 | Reference parity test against FP32 oracle in `dsv4/reference/attention.py` | TODO |
| E7 | Cleanup: delete debug test files, keep only `tests/unit/test_fmha_kernel.py` | TODO |

Block table, paged KV, FP8 dequant, and inv_scale are all handled upstream by the indexer + gather chain. FMHA sees a dense BF16 `[T, top_k, head_dim]` tile.

---

## Priority 6: Fuse FP4 quant into FMHA output → wo_a path (NVFP4-1.2)

**Depends on Priority 2** (uses the register slot in the new final epilogue).

Currently: FMHA emits BF16 → inverse RoPE → BF16 GMEM → wo_a quantizes to FP4.

Target: the register slot in FMHA's new final epilogue does `O / row_sum` *and* inverse RoPE rotation *and* per-microblock amax *and* FP4 pack. wo_a reads FP4 directly with no GMEM materialization.

Same pattern as Priority 3, different home (FMHA final epilogue, not MoE epilogue).
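
Priorities 3 and 6 share the same per-microblock arithmetic, so it can help to pin that arithmetic down on the host before writing it into an epilogue register slot. Below is a minimal NumPy sketch under stated assumptions: `quantize_microblocks` / `dequantize_microblocks` are hypothetical helpers, not part of `dsv4`. Scales stay in float32, so the FP8 E4M3 rounding the kernel would apply is not modeled. The low nibble holds the even element, which is an assumed packing order. Ties round to the lower code, which is a simplification and not necessarily the hardware rounding mode.

```python
import numpy as np

# Hypothetical reference helpers (not part of dsv4).
# e2m1 magnitudes indexed by their 3-bit code; bit 3 of each nibble is the sign.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)


def quantize_microblocks(x, block=16):
    """Per-microblock FP4 (e2m1) quantization of a 1-D array.

    Returns (packed, scale): `packed` holds two 4-bit codes per byte and
    `scale` holds one scale per `block` elements. Scales stay float32;
    rounding them to FP8 E4M3 is NOT modeled here (assumption).
    """
    x = np.asarray(x, dtype=np.float32).reshape(-1, block)
    amax = np.abs(x).max(axis=1, keepdims=True)         # per-microblock amax
    scale = amax / 6.0                                  # 6.0 = largest e2m1 magnitude
    safe = np.where(scale == 0.0, 1.0, scale)           # all-zero block: avoid 0 / 0
    mag = np.abs(x) / safe                              # now in [0, 6]
    # Nearest e2m1 magnitude; argmin returns the first hit, so ties pick the lower code.
    codes = np.abs(mag[..., None] - E2M1_VALUES).argmin(axis=-1).astype(np.uint8)
    codes |= np.signbit(x).astype(np.uint8) << 3        # set the sign bit
    flat = codes.reshape(-1)
    packed = flat[0::2] | (flat[1::2] << 4)             # assumed: low nibble = even element
    return packed.astype(np.uint8), scale[:, 0]


def dequantize_microblocks(packed, scale, block=16):
    """Inverse of quantize_microblocks; returns float32 values shaped [n_blocks, block]."""
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed & 0x0F                         # unpack low nibbles
    codes[1::2] = packed >> 4                           # unpack high nibbles
    sign = np.where(codes & 0x08, -1.0, 1.0).astype(np.float32)
    vals = sign * E2M1_VALUES[codes & 0x07]
    return vals.reshape(-1, block) * scale[:, None]
```

Because the largest gap in the e2m1 grid is between 4 and 6, the round-trip error per element is at most one `scale`. That bound is the kind of check an L1→L2 or FMHA→wo_a parity test could assert against a reference like this.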

---

## Priority 7: FP4 KV pipeline depth in FMHA (NVFP4-2)

**Depends on Priority 2 being solid at BF16 KV first.** FP4 KV shrinks tiles ~4×, so the same SMEM budget supports more pipeline stages.

| KV dtype | Tile size (hd=512) | Stages fitting 192 KB |
|---|---:|---:|
| BF16 | 128 KB | 2 |
| FP8 | 64 KB | 4 |
| FP4 | ~36 KB | 6 |

At 1M-context decode, where KV reads dominate, deeper pipelines hide more TMA latency.

**Implementation:**

- TMA loads FP4 NoPE dims (`e2m1_x2` packed) to SMEM slot 0.
- TMA loads BF16 RoPE dims to SMEM slot 1.
- TMA loads FP8 scale factors to SMEM slot 2.
- SMEM dequant FP4 → BF16, vectorized (`* FP8_scale`, 16-element microblocks).
- Concatenate `[NoPE, RoPE]` in SMEM.
- MMA reads contiguous BF16 from SMEM.

**Test:** FP4+BF16 split input must give output identical to a pure BF16 input holding the same dequantized values. Dequant must be transparent.

---

## Priority 8 (conditional on P1 profile): Per-kt O rescale fix

**Only justified if Priority 1 profiling shows Python KV merge overhead > 5% of decode latency.** Otherwise defer indefinitely; the current workaround is correct and ships.

There are three possible paths. None is a small change.

### Path A: CUTLASS atom replication

The CUTLASS C++ Blackwell FMHA does a TMEM round-trip for O rescale using `SM100_TMEM_LOAD_32dp32b_4x_atom` and `SM100_TMEM_STORE_32dp32b_4x_atom`, which are paired by hardware design. So in principle the TMEM round-trip is possible. The question is whether CuTeDSL Python exposes the specific atom variants and layout configuration CUTLASS uses, and whether they can be paired correctly through `make_tmem_copy`.

**Steps:**

- [ ] Read `/root/cutlass/.../blackwell/kernel/attention/fmha/fmha.py` (or the equivalent C++ reference) and document the exact atom + repetition + tensor layout used for `correction_rescale`.
- [ ] Enumerate what CuTeDSL Python exposes: `dir(tcgen05.copy)`, available `LdNxNbOp` / `StNxNbOp` variants, and what `Repetition(N)` controls.
- [ ] Identify the difference between the current `Ld32x32bOp(Repetition(16))` + `St32x32bOp(Repetition(16))` and whatever CUTLASS uses.
- [ ] Build a minimal NO-OP round-trip test (load O, store back unchanged) with the candidate atom configuration. Verify cos = 1.0.
- [ ] If the NO-OP passes, retest with the `* acc_scale` modification.

**Risk:** CuTeDSL Python may not expose the necessary atom variants, or may not allow the layout configuration CUTLASS uses. In that case, escalate to Path B or C.

**Effort if it works:** 2–4 days investigation + 1–2 days porting.

### Path B: O accumulator in registers, manual PV

Restructure FMHA so the PV accumulator is register-resident, not TMEM-resident. Each kt: read V from SMEM, read P from TMEM/SMEM, compute a one-shot PV (no accumulate) into a temporary TMEM region, then load it to registers and add it to the register-resident running O (with acc_scale applied to the running O before the add).

**Implications:**

- Register pressure is severe at hd=512: 512 FP32 per row × 4 B × 128 rows = 256 KB, which is the entire register file of an SM (64K 32-bit registers), and one row per thread would need 512 registers against the 255-per-thread limit. Fitting O in registers would need a smaller M tile or splitting hd into chunks.
- PV without TMEM accumulate is a non-standard MMA usage. It may need a smaller PV tile with accumulation in registers across sub-tiles.
- Loses the natural MMA-accumulator pipeline overlap.

**Effort:** 1–2 weeks. High risk of regressing the hd=64/128/256 paths during the refactor.

### Path C: O accumulator in SMEM

Variant of Path B with O in SMEM instead of registers. PV writes to a temporary TMEM region, gets loaded to registers, computes `acc_scale * existing_SMEM_O + new_PV`, and stores the result back to SMEM. The final epilogue reads from SMEM.

**Implications:**

- SMEM budget tightens significantly (O in SMEM needs ~64 KB at hd=512).
- Adds SMEM read/write pressure on every kt.
- May require dropping kv_stage to 1 across the board.

**Effort:** ~1 week. Lower risk than Path B but a bigger perf impact (more SMEM traffic).

### Recommended order if profiling demands a fix

1. Try Path A first: it is the least invasive and may just work with the right atom config.
2. If Path A is confirmed impossible in CuTeDSL Python, try Path C (SMEM-resident O) before Path B (register-resident O). SMEM gives more headroom at hd=512.
3. Use Path B only if both fail and the perf gap is truly critical.

---

## Priority 9: hd=512 single-kernel fix

**Currently blocked.** The MLIR optimizer hangs for > 3 hours on the hd=512 kernel. The tracer completes in 0.8 s, so the kernel is structurally correct. Decode works via head-packed M with `pv_n_tile=128` and `n_k_sub_tiles=2`, so this is only a prefill efficiency issue. Lower priority than the chain above.

**Options:**

1. Pre-compile the cubin offline (accept a 1–2 hour compile, cache the result).
2. Add a no-softmax mode that emits raw S to GMEM; call it twice for k_sub=0/1, accumulate in Python, and apply softmax once externally.
3. Write the hd=512 path in raw CUTLASS C++, bypassing CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
4. Report the CuTeDSL MLIR optimizer bug to NVIDIA.

---

## Priority 10: Indexer FP4 tensor-core scoring (Stage F)

Paper §5.2.1: *"the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."*

Current indexer (`dsv4/kernels/cuda/indexer_score_topk.cu`): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. This is the single largest perf gap in the codebase. At 1M-context decode it scores ~250K compressed entries per query token, and the spinlock heap will not scale to top_k=1024.

**Target:** port DeepGEMM `fp8_paged_mqa_logits` to FP4 inputs with `tcgen05.mma.kind=mxf4nvf4`, plus per-warp partial top-k merged with a final reduction tree (or radix-select), plus FP32→BF16 score quantization (the paper claims a 2× speedup on the top-k selector at 99.7% recall).

**Scope:** 2–3 weeks. Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 6 are established; they inform the indexer's FP4 load + score paths.
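
Assuming distinct scores, the "per-warp partial top-k merged with a final reduction" target is exact rather than approximate. Every global top-k entry is also in the top-k of its own slice, because at most k−1 entries in that slice can beat it. A minimal NumPy sketch of that two-level structure could serve as a CPU oracle when testing the new selector. `two_level_topk` is a hypothetical name, each chunk stands in for one warp's slice, and the final stage is a flat top-k rather than the reduction tree or radix-select named above.

```python
import numpy as np


def two_level_topk(scores, k, n_chunks):
    """Hypothetical oracle: per-chunk partial top-k, then one final top-k over survivors.

    Each chunk stands in for one warp's slice of the scored entries.
    Returns the selected indices sorted ascending.
    """
    survivors = []
    for idx in np.array_split(np.arange(scores.size), n_chunks):
        kk = min(k, idx.size)
        local = np.argpartition(scores[idx], -kk)[-kk:]     # chunk-local top-k
        survivors.append(idx[local])
    cand = np.concatenate(survivors)                        # at most n_chunks * k entries
    kk = min(k, cand.size)
    best = cand[np.argpartition(scores[cand], -kk)[-kk:]]   # final merge stage
    return np.sort(best)
```

The final stage only sees `n_chunks * k` candidates. Its cost therefore depends on the number of chunks and on top_k, not on the ~250K scored entries.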

---

## Build order: recommended sequencing

```
Priority 1 (PROFILE) ──────────────► gates Priority 8
                          │
Priority 2 (one-way      ─┼─► unblocks Priority 4 (multi-CTA)
  final epilogue)         │   unblocks Priority 6 (FP4 fuse in FMHA)
                          │
Priority 3 (NVFP4-1.1)   ─┴── parallel, independent

[verify hd regressions]
          │
          ▼
Priority 4 (D2 multi-CTA grid)
          │
          ▼
Priority 5 (Stage E production extraction)
          │
          ▼
Priority 6 (NVFP4-1.2 FP4 fuse in FMHA output)
          │
          ▼
Priority 7 (NVFP4-2 FP4 KV pipeline)
          │
          ▼
Priority 8 (per-kt rescale fix, ONLY if P1 says it matters)
          │
          ▼
Priority 9 (hd=512 fix, only if prefill efficiency demands)
          │
          ▼
Priority 10 (indexer FP4 tensor-core scoring, Stage F)
```

**Key change from the previous version:** Priority 1 is now a profiling task, not an engineering task. Priority 2 is scoped honestly: it covers the one-way path only and does *not* fix the per-kt rescale. The per-kt rescale is Priority 8, conditional, and three paths deep because there is no easy fix.

---

## Speculative: beyond what the V4 paper validated

Listed for completeness. **Do not implement without explicit sign-off.**

1. **NVFP4 compressed KV NoPE dims.** Paper validated FP8; FP4 would halve the cache again. Risk: compounds quantization noise on already-lossy compressed KV.
2. **MXFP4 vs NVFP4 for indexer scoring.** Not validated for the indexer specifically.
3. **NVFP4 for full attention Q×K^T GEMM.** Closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
4. **Per-token FP8 activation scaling in FMHA.** Not validated. Out of scope.
5. **2:4 structured sparsity on FP4 expert weights.** V4 was not trained with structured sparsity. Off the table for the released checkpoint.
6. **NVFP4 LM head + MTP head.** Big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against a held-out eval before shipping.

---

## Key numbers to remember

| Config | n_h | top_k | s_k decode | n_kv_tiles | Multi-tile? |
|---|---:|---:|---:|---:|:---|
| Flash decode | 64 | 512 | 640 | 5 | YES |
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
| Current single-tile test | 1 | n/a | 128 | 1 | NO |

Production decode needs the multi-tile path. Today's Python KV merge ships correct results at the cost of 5–9 launches per step. Whether that cost matters is what Priority 1 measures.