nvfp4-megamoe-kernel/ROADMAP.md
biondizzle bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00


ROADMAP

Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md — this file is for "what now."

Last updated: 2026-05-26 (revised after correcting D1.5 fix-path analysis)


Current status

Working

| Component | hd | n | cos | Status |
|---|---|---|---|---|
| FMHA TMEM-P | 64 | 128 | 0.999998 | |
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | |
| FMHA TMEM-P | 256 | 128 | 0.999998 | |
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | Workaround, see below |
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | |
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | |
| D5c sink merge single-tile | 64 | 128 | 0.999996 | |
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | |
| Per-head multi-head launch | 64 | 128 | 0.999995 | n_h=1–128 |
| MoE fused SwiGLU (NVFP4) | | | matches ref | Clamping in kernel |
| Dense router (sqrt-softplus) | | | matches ref | |
| Hash router | | | matches ref | |
| use_2cta_instrs conditional | | | 1.7–1.9× speedup | M≥256 prefill |
| NVFP4 primitives | | | E4M3 SF / mxf4nvf4 / 16-elem | Verified |

Known blockers

| Blocker | Impact | Status |
|---|---|---|
| Per-kt O rescale in TMEM (D1.5) | Multi-tile attention requires 5–9 kernel launches per decode step instead of 1 (Python KV merge) | Workaround is correct (cos 0.999998). Whether to fix depends on profiling; see Priority 1. |
| TMEM final-normalize round-trip | Cannot do in-kernel O /= row_sum cleanly | Already worked around: emit un-normalized O + LSE, external divide is exact. Not a blocker for shipping. |
| epilogue_tma_store blocks D2 multi-CTA + NVFP4-1.2 | Per-head Python launch wastes 128 launches per Pro decode step; FMHA output forces BF16 GMEM round-trip before wo_a | Unblocked by Priority 2 (the one-way final-epilogue rewrite, which is the only part of the MoE pattern that legitimately ports to FMHA). |
| hd=512 MLIR backend hang | Cannot compile single-kernel hd=512 (>3hr optimizer time, structurally correct) | Decode works via head-packed M with hd≤256 chunks. Single-kernel hd=512 only needed for prefill efficiency. Low priority. |
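
To illustrate why the Python KV merge workaround is exact, here is a minimal pure-Python sketch. It assumes each per-tile launch returns a normalized output row plus its log-sum-exp (LSE). The actual kernel emits un-normalized O + LSE, so the names `tile_attention` and `merge_tiles` and the data layout are hypothetical stand-ins, not the repository's API.

```python
import math

def tile_attention(q, keys, values):
    """One KV tile for one query row: returns (normalized O, LSE)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    l = sum(w)
    o = [sum(wi * v[d] for wi, v in zip(w, values)) / l
         for d in range(len(values[0]))]
    return o, m + math.log(l)

def merge_tiles(partials):
    """Combine per-tile (O, LSE) pairs into the full-softmax result."""
    lse_max = max(lse for _, lse in partials)
    weights = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(weights)
    o = [sum(w * o_i[d] for w, (o_i, _) in zip(weights, partials)) / total
         for d in range(len(partials[0][0]))]
    return o, lse_max + math.log(total)

# Toy data: 12 keys, head_dim 4, split into 3 tiles of 4 keys each.
q = [0.5, -1.0, 0.25, 2.0]
keys = [[math.sin(3 * i + d) for d in range(4)] for i in range(12)]
values = [[math.cos(5 * i + d) for d in range(4)] for i in range(12)]
parts = [tile_attention(q, keys[i:i + 4], values[i:i + 4]) for i in range(0, 12, 4)]
merged_o, merged_lse = merge_tiles(parts)
full_o, full_lse = tile_attention(q, keys, values)
```

Each tile's contribution is weighted by exp(LSE_i − LSE), so the merge reproduces the single-pass softmax up to floating-point rounding. The cost is one launch per tile, which is what Priority 1 measures.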

What the MoE epilogue pattern actually buys us (and what it doesn't)

This deserves stating explicitly because the previous version of this document had it wrong.

The MoE pattern is one-way only. dsv4/kernels/gemm/fused_swiglu.py uses epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition to construct a TMEM → registers → SMEM → GMEM pipeline. There is no corresponding store-back-to-TMEM helper, and no inverse pairing of the t2r atom. The MoE epilogue runs once, after all MMA K-tiles are accumulated. It never needs to mutate the TMEM accumulator mid-loop.

FMHA's per-kt O rescale is structurally different. PV uses tcgen05.mma with ACCUMULATE=True across the kt loop. The accumulator must live in TMEM because that's where MMA reads it and writes it. When row_max changes between kt iterations, the running O accumulator has to be multiplied by acc_scale in TMEM before the next PV — load to registers, multiply, store back. The store-back is the part that's broken: Ld32x32bOp and St32x32bOp built as separate atoms have hardware column mappings that don't match, producing ~3% corruption even on NO-OP round-trip. No software layout transformation in CuTeDSL Python has, so far, made them pair correctly.
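
The rescale described above can be sketched in plain Python for a single query row. This is a numerical illustration only: the list `o` stands in for the TMEM accumulator, and the function names and toy data are assumptions made for the example.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_reference(q, keys, values):
    """softmax(q . K^T) V for a single query row, computed in one pass."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total
            for d in range(len(values[0]))]

def attention_tiled(q, keys, values, tile):
    """Same result, one KV tile at a time, rescaling the running O per tile."""
    row_max, row_sum = -math.inf, 0.0
    o = [0.0] * len(values[0])  # un-normalized running accumulator
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [dot(q, k) for k in k_tile]
        new_max = max(row_max, max(scores))
        acc_scale = math.exp(row_max - new_max)  # exactly 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]
        row_sum = row_sum * acc_scale + sum(p)
        # Load / multiply / store-back: the old O is rescaled before PV is added.
        o = [o_d * acc_scale + sum(pi * v[d] for pi, v in zip(p, v_tile))
             for d, o_d in enumerate(o)]
        row_max = new_max
    return [o_d / row_sum for o_d in o]

# Deterministic toy data: 12 keys, head_dim 4, processed 4 keys per tile.
q = [0.5, -1.0, 0.25, 2.0]
keys = [[math.sin(3 * i + d) for d in range(4)] for i in range(12)]
values = [[math.cos(5 * i + d) for d in range(4)] for i in range(12)]
out_ref = attention_reference(q, keys, values)
out_tiled = attention_tiled(q, keys, values, tile=4)
```

The multiplication `o_d * acc_scale` is the step that needs a correct TMEM load and store-back in the kernel. Everything else in the loop is one-way.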

Therefore: porting the MoE pattern to FMHA only fixes the one-way paths:

  1. The final epilogue (after the last kt, when O is being written to GMEM for good).
  2. Any FP4 amax + pack fusion into that final epilogue.

It does not fix the per-kt rescale. That is a separate, harder problem with three possible paths laid out below.


Priority 1: Profile production decode to determine if the rescale fix is needed

Before investing days in fixing the per-kt rescale, measure whether the 5–9× launch overhead from Python KV merge is actually a bottleneck.

At Pro decode (s_k=1152, n_kv_tiles=9), Python merge dispatches 9 kernels per step. Conservative launch overhead ~50 μs per kernel ≈ 450 μs/step in launch overhead alone. If a full decode step (all 61 layers, MoE, embedding, sampler) takes ~30 ms, that's ~1.5% of latency. If it takes ~10 ms, it's ~4.5%. Whether that's worth a 1–2 week refactor depends on the actual measurement.

Action:

  • Profile Pro decode at s_k=1152 with current Python KV merge. Measure: total step latency, launch overhead from FMHA dispatches, FMHA compute time per launch.
  • Measure CPU dispatch overhead on the host (Python loop + kernel cache lookup).
  • Decision rule: if Python merge overhead is < 5% of total decode latency, defer Priority 8 indefinitely. Ship Python merge as production path.

Done when: there's a profiled number that justifies (or doesn't) the engineering investment in Priority 8.
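
The estimate and decision rule for this priority can be written as a short calculation. The 50 μs launch cost and the 10 ms / 30 ms step times are the placeholder figures from the text, not measurements, and the function names are illustrative.

```python
def merge_overhead_fraction(n_kv_tiles, launch_us, step_ms):
    """Fraction of one decode step spent on FMHA launch overhead."""
    return (n_kv_tiles * launch_us) / (step_ms * 1000.0)

def should_defer_rescale_fix(fraction, threshold=0.05):
    """Decision rule: defer Priority 8 while overhead stays under 5%."""
    return fraction < threshold

slow_step = merge_overhead_fraction(n_kv_tiles=9, launch_us=50, step_ms=30)
fast_step = merge_overhead_fraction(n_kv_tiles=9, launch_us=50, step_ms=10)
```

Both placeholder scenarios land below the 5% threshold, so the profiled launch cost and step latency decide the outcome.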


Priority 2: One-way final-epilogue rewrite

What: Replace the utils.gemm.sm100.epilogue_tma_store(...) call at the end of FMHA (dsv4/kernels/attention/fmha.py lines 565–577) with the MoE-style explicit pipeline:

transform_partitioned_tensor_layout → epilogue_tmem_copy_and_partition →
[register slot — optional normalize/cast/FP4 pack] →
epilogue_smem_copy_and_partition → flat_divide → cpasync.tma_partition → TMA store

This is strictly one-way. The kt-loop rescale code stays exactly as it is (using the broken hand-built atoms — see Priority 8 for the fix path).

What this enables:

  1. Optional in-kernel normalize. Adds an if const_expr(self.normalize): block at the register slot to multiply by inv_row_sum. Currently external code does the divide on the un-normalized output. Tiny perf win, not the main reason to do this.
  2. Unblocks NVFP4-1.2 (Priority 6) — gives a register-level modification slot in the FMHA output path where FP4 amax + pack can live, eliminating the BF16 GMEM materialization between FMHA and wo_a.
  3. Likely unblocks D2 multi-CTA grid (Priority 4) — the current epilogue_tma_store is what couldn't accept the flat_divide-based GMEM coordinate system. Switching to the explicit cpasync.tma_partition(tma_c, ..., cute.flat_divide(tCgC_transformed, epi_tile)) path puts FMHA on the same TMA pattern MoE uses successfully, which should accept multi-CTA block coordinates.

Caveats to verify on B200 before assuming this works:

  • Whether tma_partition survives inside the if warp_idx < self.mma_warp_id block. MoE calls it from inside its epilogue warp's if, but that has not been tested in FMHA's region tree. Previous attempts at the full pattern triggered 20+ minute MLIR compile times before reaching a verdict.
  • Whether transform_partitioned_tensor_layout accepts FMHA's tOtO0 (TMEM iterator with offset) directly, or whether it needs a fresh tensor built at tmem_ptr + self.tmem_o0_offset with tCtO_fake.layout.
  • The epilogue_tmem_copy_and_partition helper signature on the current CuTeDSL version — print on B200 before coding.

Failure mode to watch for: if compile hangs as it did previously, this rewrite is genuinely blocked and the chain of follow-ons (Priorities 4 and 6) needs alternative paths.

Effort: 1–2 days if the helpers cooperate. Multiple days if the MLIR hang reappears.

Done when:

  • hd=64/128/256 regression cos ≥ 0.999998 holds with both normalize=True and normalize=False paths.
  • LSE output still matches reference for normalize=False callers.
  • Compile time is reasonable (< 5 min) — if not, document the hang and fall back.

Priority 3: NVFP4-1.1 — Fuse FP4 quant into MoE SwiGLU epilogue

Independent of FMHA. Can run in parallel with Priority 2. Biggest bandwidth win in the codebase.

Current:

padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
                                  ↓
                          quantize_activation_nvfp4 (separate kernel)
                                  ↓
                         padded_activated_fp4 → L2 GEMM

Target:

padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM

The SwiGLU + clamp result already lives in registers at tRS_rC.store(acc_vec_bf16) (line 2207 of fused_swiglu.py). That's the slot for amax + FP4 pack.

Per-microblock amax (16 contiguous elements):

  1. shfl_xor butterfly reduction across the 4 threads holding the 16 elements.
  2. FP8 E4M3 scale = amax / 6 (FP4 e2m1 max).
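  3. Per-element FP4 pack: sign_bit << 3 | e2m1_code(|val| / scale), where e2m1_code (illustrative name) is the 3-bit code of the nearest representable magnitude in (0, 0.5, 1, 1.5, 2, 3, 4, 6). A plain integer cast is not enough because the e2m1 grid is non-uniform. Two elements → one byte.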
  4. 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
  5. FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.
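
The per-microblock steps above can be sketched on the host in plain Python. This is a reference-style illustration under stated assumptions: the first element of each pair goes in the low nibble, ties round toward the smaller magnitude, and the scale is kept as a Python float instead of being rounded to E4M3. None of these choices is confirmed by the kernel source.

```python
# Representable e2m1 magnitudes, indexed by their 3-bit code.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

def quantize_microblock(values):
    """Quantize 16 contiguous values to (scale, 8 packed bytes)."""
    assert len(values) == 16
    amax = max(abs(v) for v in values)
    scale = amax / 6.0 if amax > 0.0 else 1.0  # 6.0 is the e2m1 maximum
    nibbles = []
    for v in values:
        mag = min(abs(v) / scale, 6.0)
        # Nearest representable magnitude; ties pick the smaller code here.
        code = min(range(8), key=lambda c: abs(E2M1_MAGNITUDES[c] - mag))
        sign = 1 if v < 0.0 else 0
        nibbles.append((sign << 3) | code)
    # Assumed order: first element in the low nibble of each byte.
    packed = bytes(nibbles[i] | (nibbles[i + 1] << 4) for i in range(0, 16, 2))
    return scale, packed

block = [6.0, -6.0, 3.0, 0.0] + [0.0] * 12
scale, packed = quantize_microblock(block)
```

Sixteen 4-bit codes fill exactly one 64-bit word, which matches the SMEM staging unit in step 4.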

Done when: padded_activated_fp4 and padded_activated_x_sf scratch buffers go away, quantize_activation_nvfp4 between L1 and L2 disappears, L1→L2 cosine matches reference.


Priority 4: D2 multi-CTA grid

Depends on Priority 2. Currently per-head Python launch dispatches 128 kernels per Pro decode step. Multi-CTA grid collapses that to 1.

Grid: (num_M_tiles, num_query_heads, batch) — at decode T=1: (1, 128, batch).

Q tensor layout: Option 1 — (batch, n_h, T, head_dim) with head as a TMA mode. Matches CUTLASS reference, allows per-head LSE output, generalizes to GQA later.

MQA K/V sharing: start with independent K/V loads per CTA (each loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB of K/V reads per launch, which is small next to available HBM bandwidth. Cluster-wide sharing via cluster_shape_mn=(1, num_query_heads, 1) is a future optimization once profiling shows it matters.

Done when: n_h=128, batch=4, T=1 at hd=512 produces correct output with single launch, per-head LSE writes to mLSE[batch, head, m_row] correctly.
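
As a quick check on the grid shape and the K/V traffic estimate, the arithmetic can be written out. The M tile of 128 rows is an assumption for illustration, and the helper names are hypothetical.

```python
import math

def fmha_grid(T, n_h, batch, m_tile=128):
    """Grid as (num_M_tiles, num_query_heads, batch); m_tile is assumed."""
    return (math.ceil(T / m_tile), n_h, batch)

def independent_kv_traffic_mb(n_ctas, kv_kb_per_cta):
    """Total K/V bytes read when every CTA loads its own copy, in MB."""
    return n_ctas * kv_kb_per_cta / 1024

decode_grid = fmha_grid(T=1, n_h=128, batch=4)
kv_mb = independent_kv_traffic_mb(n_ctas=128, kv_kb_per_cta=128)
```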


Priority 5: Stage E — Production extraction

D5 is complete. Wrap the kernel in a proper interface.

| Step | What | Status |
|---|---|---|
| E1 | File placement: dsv4/kernels/attention/fmha.py | Done |
| E2 | Constructor signature (head_dim, num_query_heads, sliding_window, top_k, sink/causal flags, dtypes) | ⚠️ Partial, needs cleanup |
| E3 | Call signature: q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o, stream | ⚠️ Needs sink_bias / row_sums integration |
| E4 | Kernel cache + warmup, keyed on (head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...) | TODO |
| E5 | torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",)) | TODO |
| E6 | Reference parity test against FP32 oracle in dsv4/reference/attention.py | TODO |
| E7 | Cleanup: delete debug test files, keep only tests/unit/test_fmha_kernel.py | TODO |

Block table, paged KV, FP8 dequant, inv_scale — all handled upstream by the indexer + gather chain. FMHA sees a dense BF16 [T, top_k, head_dim] tile.
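
One possible shape for the E4 kernel cache is sketched below. The key fields mirror the tuple listed in E4. `FmhaKey`, `KernelCache`, and the `compile_fn` callback are hypothetical names, and the real key would carry whatever the "..." in E4 stands for.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FmhaKey:
    """Hashable compile-time configuration (fields taken from E4)."""
    head_dim: int
    num_query_heads: int
    top_k: int
    n_comp: int
    apply_sink_bias: bool
    is_causal: bool

class KernelCache:
    def __init__(self, compile_fn):
        self._compile_fn = compile_fn  # stand-in for the real compile step
        self._kernels = {}
        self.compiles = 0

    def get(self, key):
        kernel = self._kernels.get(key)
        if kernel is None:
            kernel = self._compile_fn(key)
            self._kernels[key] = kernel
            self.compiles += 1
        return kernel

    def warmup(self, keys):
        """Compile every expected configuration before serving traffic."""
        for key in keys:
            self.get(key)

cache = KernelCache(compile_fn=lambda key: ("compiled", key))
pro = FmhaKey(512, 128, 1024, 1024, True, True)
cache.warmup([pro])
kernel = cache.get(pro)
```

A frozen dataclass keeps the key hashable and makes a missing field a construction error instead of a silent cache collision.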


Priority 6: NVFP4-1.2 — Fuse FP4 quant into FMHA output → wo_a path

Depends on Priority 2 (uses the register slot in the new final epilogue).

Currently: FMHA emits BF16 → inverse RoPE → BF16 GMEM → wo_a quantizes to FP4.

Target: the register slot in FMHA's new final epilogue performs O / row_sum, inverse RoPE rotation, per-microblock amax, and FP4 pack. wo_a reads FP4 directly with no GMEM materialization.

Same pattern as Priority 3, different home (FMHA final epilogue, not MoE epilogue).


Priority 7: NVFP4-2 — FP4 KV pipeline depth in FMHA

Depends on Priority 2 being solid at BF16 KV first. FP4 KV shrinks tiles ~4×; same SMEM budget supports more pipeline stages.

| KV dtype | Tile size (hd=512) | Stages fitting 192 KB |
|---|---|---|
| BF16 | 128 KB | 2 |
| FP8 | 64 KB | 4 |
| FP4 | ~36 KB | 6 |

At 1M-context decode where KV reads dominate, deeper pipelines hide more TMA latency.

Implementation:

  • TMA loads FP4 NoPE dims (e2m1_x2 packed) to SMEM slot 0.
  • TMA loads BF16 RoPE dims to SMEM slot 1.
  • TMA loads FP8 scale factors to SMEM slot 2.
  • SMEM dequant FP4 → BF16 vectorized (* FP8_scale, 16-element microblocks).
  • Concatenate [NoPE, RoPE] in SMEM.
  • MMA reads contiguous BF16 from SMEM.

Test: FP4+BF16 split input → identical output to pure BF16 input. Dequant must be transparent.
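
The dequant step in the list above can be sketched as a host-side reference in plain Python. It assumes the low nibble of each byte holds the earlier element and takes the scale as an already-decoded float. Both are assumptions, and `dequantize_microblock` is an illustrative name.

```python
# Representable e2m1 magnitudes, indexed by their 3-bit code.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)

def dequantize_microblock(packed, scale):
    """Expand 8 packed bytes (16 e2m1 codes) into 16 scaled floats."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):  # assumed: low nibble first
            mag = E2M1_MAGNITUDES[nibble & 0x7] * scale
            out.append(-mag if nibble & 0x8 else mag)
    return out

packed = bytes([0xF7, 0x05, 0, 0, 0, 0, 0, 0])
restored = dequantize_microblock(packed, scale=2.0)
```

A reference like this gives the transparency test something to compare against: feed the kernel the packed input, feed the BF16 path the `restored` values, and the outputs should agree.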


Priority 8 (conditional on P1 profile): Per-kt O rescale fix

Only justified if Priority 1 profiling shows Python KV merge overhead > 5% of decode latency. Otherwise defer indefinitely — the current correct workaround ships.

There are three possible paths. None is a small change.

Path A: CUTLASS atom replication

The CUTLASS C++ Blackwell FMHA does a TMEM round-trip for O rescale using SM100_TMEM_LOAD_32dp32b_4x_atom and SM100_TMEM_STORE_32dp32b_4x_atom, which are paired by hardware design. So in principle TMEM round-trip is possible. The question is whether CuTeDSL Python exposes the specific atom variants and layout configuration CUTLASS uses, and whether they can be paired correctly through make_tmem_copy.

Steps:

  • Read /root/cutlass/.../blackwell/kernel/attention/fmha/fmha.py (or equivalent C++ reference) and document the exact atom + repetition + tensor layout used for correction_rescale.
  • Enumerate what CuTeDSL Python exposes: dir(tcgen05.copy), available LdNxNbOp / StNxNbOp variants, what Repetition(N) controls.
  • Identify the difference between current Ld32x32bOp(Repetition(16)) + St32x32bOp(Repetition(16)) and whatever CUTLASS uses.
  • Build a minimal NO-OP round-trip test (load O, store back unchanged) with the candidate atom configuration. Verify cos = 1.0.
  • If NO-OP passes, retest with * acc_scale modification.
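
The last two steps can share one small harness, sketched below. The `roundtrip` argument is a hypothetical stand-in for whatever launches the TMEM load/store kernel and returns the accumulator contents; the tolerance is also an assumption.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def check_roundtrip(roundtrip, o, acc_scale=1.0, tol=1e-6):
    """True if roundtrip(o, acc_scale) matches o * acc_scale by cosine."""
    expected = [x * acc_scale for x in o]
    return cosine(roundtrip(o, acc_scale), expected) >= 1.0 - tol

def good_roundtrip(o, acc_scale):
    return [x * acc_scale for x in o]

def corrupt_roundtrip(o, acc_scale):
    # Simulates a column-mapping bug by flipping the last two columns.
    scaled = [x * acc_scale for x in o]
    return scaled[:-2] + [-scaled[-2], -scaled[-1]]

o_row = [float(i) for i in range(1, 65)]
noop_ok = check_roundtrip(good_roundtrip, o_row)
scaled_ok = check_roundtrip(good_roundtrip, o_row, acc_scale=0.5)
noop_bad = check_roundtrip(corrupt_roundtrip, o_row)
```

Running the NO-OP case (acc_scale=1.0) first separates atom-pairing bugs from arithmetic bugs, which is the ordering the steps call for.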

Risk: CuTeDSL Python may not expose the necessary atom variants, or may not allow the layout configuration CUTLASS uses. In that case, escalate to Path B or C.

Effort if it works: 2–4 days investigation + 1–2 days porting.

Path B: O accumulator in registers, manual PV

Restructure FMHA so the PV accumulator is register-resident, not TMEM-resident. Each kt: read V from SMEM, read P from TMEM/SMEM, compute one-shot PV (no accumulate) writing to a temporary TMEM region, then load to registers and add to register-resident running O (with acc_scale applied to the running O before the add).

Implications:

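  • Register pressure is severe at hd=512. 512 FP32 per row × 1 row per thread × 128 threads = 256 KB of registers just for O (512 × 4 B × 128). That is the full 64K-entry register file of one SM, and 512 accumulator registers per thread exceeds the 255-register per-thread limit, so O would have to be spread over more threads or sub-tiled.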
  • PV without TMEM accumulate is a non-standard MMA usage. May need to use a smaller PV tile and accumulate in registers across sub-tiles.
  • Loses the natural MMA-accumulator pipeline overlap.

Effort: 1–2 weeks. High risk of regressing hd=64/128/256 paths during the refactor.

Path C: O accumulator in SMEM

Variant of Path B with O in SMEM instead of registers. PV writes to a temporary TMEM region, gets loaded to registers, applies acc_scale * existing_SMEM_O + new_PV, stores back to SMEM. Final epilogue reads from SMEM.

Implications:

  • SMEM budget tightens significantly (need O in SMEM = ~64 KB at hd=512).
  • Adds SMEM read/write pressure on every kt.
  • May require dropping kv_stage to 1 across the board.

Effort: ~1 week. Lower risk than Path B but bigger perf impact (more SMEM traffic).

Recommended order:

  1. Try Path A first: least invasive, may just work with the right atom config.
  2. If Path A confirmed impossible in CuTeDSL Python, try Path C (SMEM-resident O) before Path B (register-resident O). SMEM gives more headroom at hd=512.
  3. Path B only if both fail and the perf gap is truly critical.

Priority 9: hd=512 single-kernel fix

Currently blocked. MLIR optimizer hangs > 3 hours on the hd=512 kernel. Tracer completes in 0.8s — kernel is structurally correct.

Decode works via head-packed M with pv_n_tile=128 and n_k_sub_tiles=2, so this is only a prefill efficiency issue. Lower priority than the chain above.

Options:

  1. Pre-compile cubin offline (accept 1–2 hour compile, cache result).
  2. Add no-softmax mode emitting raw S to GMEM; call twice for k_sub=0/1, accumulate in Python, softmax once externally.
  3. Write hd=512 path in raw CUTLASS C++. Bypasses CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
  4. Report CuTeDSL MLIR optimizer bug to NVIDIA.
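
Option 2 relies on the QK dot product splitting cleanly along head_dim. A minimal sketch of that identity, using toy sizes (head_dim 8 split into two sub-tiles of 4) and illustrative helper names:

```python
import math

def partial_scores(q, keys, lo, hi):
    """Raw S restricted to head_dim columns [lo, hi), with no softmax."""
    return [sum(q[d] * k[d] for d in range(lo, hi)) for k in keys]

def softmax(scores):
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [x / total for x in w]

HEAD_DIM, SUB = 8, 4
q = [math.sin(d + 1) for d in range(HEAD_DIM)]
keys = [[math.cos(2 * i + d) for d in range(HEAD_DIM)] for i in range(6)]
s0 = partial_scores(q, keys, 0, SUB)           # k_sub = 0 launch
s1 = partial_scores(q, keys, SUB, HEAD_DIM)    # k_sub = 1 launch
p_split = softmax([a + b for a, b in zip(s0, s1)])
p_full = softmax(partial_scores(q, keys, 0, HEAD_DIM))
```

The softmax must run once on the summed scores. Applying it per sub-tile and combining afterwards would not give the same probabilities.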

Priority 10: Indexer FP4 tensor-core scoring (Stage F)

Paper §5.2.1: "the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."

Current indexer (dsv4/kernels/cuda/indexer_score_topk.cu): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. Single largest perf gap in the codebase. At 1M-context decode it scores ~250K compressed entries per query token — the spinlock heap will not scale to top_k=1024.

Target: port DeepGEMM fp8_paged_mqa_logits to FP4 inputs with tcgen05.mma.kind=mxf4nvf4. Plus per-warp partial top-k merged with a final reduction tree (or radix-select). Plus FP32→BF16 score quantization (paper claims 2× speedup on top-k selector at 99.7% recall).
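
The per-warp partial top-k plus final merge can be illustrated on the host with heapq. Each partition stands in for one warp's slice of the scores. The partition count and function name are assumptions, and the sketch says nothing about the on-device data structure.

```python
import heapq

def topk_partitioned(scores, k, n_parts):
    """Top-k indices via per-partition top-k followed by one merge."""
    indexed = list(enumerate(scores))
    chunk = max(1, -(-len(indexed) // n_parts))  # ceil division
    candidates = []
    for start in range(0, len(indexed), chunk):
        part = indexed[start:start + chunk]
        # Each partition keeps only its own k best candidates.
        candidates.extend(heapq.nlargest(k, part, key=lambda t: t[1]))
    winners = heapq.nlargest(k, candidates, key=lambda t: t[1])
    return sorted(i for i, _ in winners)

# Distinct deterministic scores (7919 and 10007 are coprime).
scores = [(i * 7919) % 10007 for i in range(5000)]
partitioned = topk_partitioned(scores, k=64, n_parts=32)
direct = sorted(i for i, _ in heapq.nlargest(64, enumerate(scores), key=lambda t: t[1]))
```

Because every global winner is also among the k best of its own partition, the merge needs no locking between partitions, unlike the shared spinlock heap.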

Scope: 2–3 weeks. Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 6 are established — they inform the indexer's FP4 load + score paths.


Priority 1 (PROFILE) ──► gates Priority 8
                        │
Priority 2 (one-way    ─┼─► unblocks Priority 4 (multi-CTA)
final epilogue)         │   unblocks Priority 6 (FP4 fuse in FMHA)
                        │
Priority 3 (NVFP4-1.1) ─┴── parallel, independent
                        
[verify hd regressions]
        │
        ▼
Priority 4 (D2 multi-CTA grid)
        │
        ▼
Priority 5 (Stage E production extraction)
        │
        ▼
Priority 6 (NVFP4-1.2 FP4 fuse in FMHA output)
        │
        ▼
Priority 7 (NVFP4-2 FP4 KV pipeline)
        │
        ▼
Priority 8 (per-kt rescale fix — ONLY if P1 says it matters)
        │
        ▼
Priority 9 (hd=512 fix — only if prefill efficiency demands)
        │
        ▼
Priority 10 (indexer FP4 tensor-core scoring) — Stage F

Key change from the previous version: Priority 1 is now a profiling task, not an engineering task. Priority 2 is scoped honestly — it's the one-way path only, and it does not fix the per-kt rescale. The per-kt rescale is Priority 8, conditional, and three paths deep because there is no easy fix.


Speculative — beyond what the V4 paper validated

Listed for completeness. Do not implement without explicit sign-off.

  1. NVFP4 compressed KV NoPE dims. Paper validated FP8; FP4 would halve cache again. Risk: compounds quantization noise on already-lossy compressed KV.
  2. MXFP4 vs NVFP4 for indexer scoring. Not validated for indexer specifically.
  3. NVFP4 for full attention Q×K^T GEMM. Closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
  4. Per-token FP8 activation scaling in FMHA. Not validated. Out of scope.
  5. 2:4 structured sparsity on FP4 expert weights. V4 not trained with structured sparsity. Off the table for the released checkpoint.
  6. NVFP4 LM head + MTP head. Big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against held-out eval before shipping.

Key numbers to remember

| Config | n_h | top_k | s_k decode | n_kv_tiles | Multi-tile? |
|---|---|---|---|---|---|
| Flash decode | 64 | 512 | 640 | 5 | YES |
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
| Current single-tile test | 1 | | 128 | 1 | NO |

Production decode needs the multi-tile path. Today's Python KV merge ships correct results at the cost of 5–9 launches per step. Whether that cost matters is what Priority 1 measures.
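
The tile counts in the table follow from s_k, as this small sketch shows. The KV tile of 128 is inferred from the table (640 → 5, 1152 → 9) and is not stated as a constant in this document.

```python
import math

KV_TILE = 128  # inferred from the table, treat as an assumption

def n_kv_tiles(s_k, tile=KV_TILE):
    """KV tiles per decode step, and launches per step under Python merge."""
    return math.ceil(s_k / tile)

flash_tiles = n_kv_tiles(640)
pro_tiles = n_kv_tiles(1152)
single_tiles = n_kv_tiles(128)
```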