nvfp4-megamoe-kernel/ROADMAP.md
biondizzle 43f0b5d1e8 D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.

Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.

Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
2026-05-26 19:34:26 +00:00


ROADMAP

Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md — this file is for "what now."

Last updated: 2026-05-26


Current status

Working

| Component | hd | n | cos | Status |
|---|---|---|---|---|
| FMHA TMEM-P | 64 | 128 | 0.999998 | |
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | |
| FMHA TMEM-P | 256 | 128 | 0.999998 | |
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | Workaround |
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | |
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | |
| D5c sink merge single-tile | 64 | 128 | 0.999996 | |
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | |
| Per-head multi-head launch | 64 | 128 | 0.999995 | n_h=1–128 |
| MoE fused SwiGLU (NVFP4) | | | matches ref | Clamping in kernel |
| Dense router (sqrt-softplus) | | | matches ref | |
| Hash router | | | matches ref | |
| use_2cta_instrs conditional | | | 1.7–1.9× speedup | M≥256 prefill |
| NVFP4 primitives | | | E4M3 SF / mxf4nvf4 / 16-elem | Verified |

Known blockers

| Blocker | Impact | Workaround | Fix path |
|---|---|---|---|
| D1.5 TMEM round-trip corruption | Hand-built atoms produce 3% error on a no-op round-trip; blocks in-kernel multi-tile O rescale and in-kernel normalize | Emit un-normalized O + LSE; Python KV merge for s_k>128 | Priority 1: correction-epilog rewrite (sketch below) |
| hd=512 MLIR backend hang | Cannot compile the hd=512 kernel (>3hr optimizer time, structurally correct) | Run hd=512 via head-packed M with hd≤256 chunks; ship without hd=512 if needed for D2 | Pre-compile cubin / raw CUTLASS C++ / report NVIDIA bug |
| D2 multi-CTA grid (flat_divide + epilogue_tma_store) | Per-head Python launch needs 128 launches per decode step at Pro | Per-head launch (works, just slow) | Unblocked by correction-epilog rewrite (uses flat_divide + tma_partition like MoE does) |
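The Python KV merge workaround relies on the standard log-sum-exp identity for combining softmax attention computed over disjoint KV tiles. A minimal pure-Python sketch for a single query row, assuming each tile contributes a tile-normalized output plus its LSE (the kernel actually emits un-normalized O + LSE, so it needs the matching row sum or max to get to this form). `attend_tile` and `merge_tiles` are hypothetical helpers, not the code in test_d15_multi_kv.py:

```python
import math


def attend_tile(scores, values):
    """Softmax attention for one query row over one KV tile.

    Returns (tile-normalized output vector, tile log-sum-exp).
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    row_sum = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / row_sum for d in range(dim)]
    return out, m + math.log(row_sum)


def merge_tiles(outputs, lses):
    """Combine per-tile (output, lse) pairs into the full-row result.

    Each tile's weight is exp(lse_t), its share of the full softmax
    denominator; subtracting the max LSE keeps the exponentials in range.
    """
    lse_max = max(lses)
    coeffs = [math.exp(lse - lse_max) for lse in lses]
    total = sum(coeffs)
    dim = len(outputs[0])
    merged = [sum(c * o[d] for c, o in zip(coeffs, outputs)) / total for d in range(dim)]
    return merged, lse_max + math.log(total)
```

Splitting a 256-key row into two 128-key tiles and merging them gives the same output and LSE as attending over all 256 keys at once, up to floating-point rounding.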

Priority 1: Correction epilog rewrite (unblocks D1.5 + a chain of follow-ons)

Why this first: Every downstream item needs the kernel to have a register-level slot in the epilogue for modification. The current epilogue_tma_store path with hand-built atoms doesn't have one. The correction-epilog pattern does.

The pattern is already in the codebase: dsv4/kernels/gemm/fused_swiglu.py uses it for the MoE SwiGLU epilogue (lines 2021, 2064–2229). It relies on library helpers, paired atoms, and a one-way TMEM → registers → SMEM → GMEM flow. The SwiGLU + clamping math sits between the t2r and r2s copies, which is exactly the slot FMHA needs.

What changes:

Replace the FMHA epilogue (dsv4/kernels/attention/fmha.py lines 549–597, the epilogue_tma_store call) with:

  1. Setup (run once per kernel, outside the kt loop):

    • tCtO_transformed = utils.gemm.sm100.transform_partitioned_tensor_layout(tOtO0)
    • tCgC_transformed = utils.gemm.sm100.transform_partitioned_tensor_layout(tCgC)
    • tiled_copy_t2r, tTR_tO_base, tTR_rO = utils.gemm.sm100.epilogue_tmem_copy_and_partition(...)
    • tiled_copy_r2s, tRS_rC, tRS_sC = utils.gemm.sm100.epilogue_smem_copy_and_partition(...)
    • TMA partition for C via flat_divide(tCgC_transformed, epi_tile) + cpasync.tma_partition(...)
  2. Final epilogue (replaces the round-trip normalize):

    • Subtile loop. In each subtile: cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO) → multiply by inv_row_sum in registers if normalize=True → cast to BF16 → cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[...]) → TMA store SMEM → GMEM.
  3. Per-kt O rescale (replaces the broken hand-built round-trip on lines 524–544):

    • Inside the kt loop, when kt > 0: same t2r → multiply by acc_scale in registers → store back via the paired atom (tiled_copy_t2r.retile_to_S() or equivalent — verify exact API on B200).
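Whatever form the CUTLASS store-back API takes, the register math in steps 2 and 3 is ordinary online softmax. A scalar Python model for one query row, assuming 128-key KV tiles and natural-exponent `exp` (the kernel may use `exp2` with a log2(e) prescale). `flash_attention_row` is a hypothetical name; `acc_scale` and `inv_row_sum` mirror the quantities named above:

```python
import math


def flash_attention_row(scores, values, tile=128, normalize=True):
    """Scalar model of the per-kt O rescale plus the final normalize for one query row."""
    dim = len(values[0])
    o = [0.0] * dim  # un-normalized accumulator (the O that lives in TMEM)
    row_max = -math.inf
    row_sum = 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(row_max, max(s_tile))
        if start > 0:
            # kt > 0: re-base the accumulator and row sum onto the new running max.
            acc_scale = math.exp(row_max - new_max)
            o = [x * acc_scale for x in o]
            row_sum *= acc_scale
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - new_max)
            row_sum += p
            for d in range(dim):
                o[d] += p * v[d]
        row_max = new_max
    if normalize:
        # Final epilogue: multiply by inv_row_sum in registers before the BF16 cast.
        inv_row_sum = 1.0 / row_sum
        o = [x * inv_row_sum for x in o]
    return o, row_max + math.log(row_sum)
```

With `normalize=False` the function returns the un-normalized O and the LSE, which is the pair the current workaround emits.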

What unblocks:

  • D1.5 issue 1 (round-trip corruption): gone.
  • D1.5 issue 2 (per-kt rescale): gone.
  • In-kernel multi-tile attention (a single launch for s_k=1152 instead of 9).
  • NVFP4-1.2 (fuse FP4 quant into FMHA output → wo_a path): the register slot is where amax + FP4 pack go.
  • D2 multi-CTA grid: flat_divide + tma_partition path is the same one MoE uses successfully. The flat_divide vs local_tile mismatch resolves.

Caveats to print and verify on B200:

  • Exact CUTLASS helper API for the "store back to TMEM" direction (retile_to_S form vs separate helper vs same-base-tensor pattern).
  • Whether transform_partitioned_tensor_layout accepts tOtO0 (TMEM iterator with offset) or needs a fresh tensor built at tmem_ptr + self.tmem_o0_offset with tCtO_fake.layout.
  • Whether tma_partition inside if warp_idx < self.mma_warp_id works in this kernel's region tree. The MoE kernel does it; if FMHA hits "weakly congruent," hoist the partition call above the warp branch.

Done when:

  • hd=64/128/256 regression cos ≥ 0.999998 holds with normalize=True and normalize=False.
  • New multi-tile s_k=256 test with kt > 0 rescale gives cos ≥ 0.999998 (not the current 0.997 Python-merge workaround, the real in-kernel rescale).
  • Existing Python KV merge tests continue to pass (test_d15_multi_kv.py).

Priority 2: Stage E — Production extraction

D5 is complete. The kernel works. Wrap it in a proper interface.

| Step | What | Status |
|---|---|---|
| E1 | File placement: dsv4/kernels/attention/fmha.py | Done |
| E2 | Constructor signature (head_dim, num_query_heads, sliding_window, top_k, sink/causal flags, dtypes) | ⚠️ Partial: needs cleanup |
| E3 | Call signature: q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o, stream | ⚠️ Needs sink_bias / row_sums integration |
| E4 | Kernel cache + warmup, keyed on (head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...) | TODO |
| E5 | torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",)) | TODO |
| E6 | Reference parity test against FP32 oracle in dsv4/reference/attention.py | TODO |
| E7 | Cleanup: delete debug test files, keep only tests/unit/test_fmha_kernel.py | TODO |
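Step E4 amounts to a compile-once cache keyed on the static configuration. A minimal sketch, assuming a hypothetical `compile_fn` that builds one specialized kernel per key; the field names in the usage come from the E4 key list and the values are placeholders:

```python
from typing import Any, Callable, Dict, Iterable, Tuple


class KernelCache:
    """Compile each static FMHA configuration once and reuse it on later calls."""

    def __init__(self, compile_fn: Callable[..., Any]):
        self._compile_fn = compile_fn  # hypothetical kernel builder
        self._kernels: Dict[Tuple, Any] = {}

    @staticmethod
    def _key(config: Dict[str, Any]) -> Tuple:
        # Sort by field name so keyword order never produces a second entry.
        return tuple(sorted(config.items()))

    def get(self, **config: Any) -> Any:
        key = self._key(config)
        if key not in self._kernels:
            self._kernels[key] = self._compile_fn(**config)
        return self._kernels[key]

    def warmup(self, configs: Iterable[Dict[str, Any]]) -> None:
        """Compile every expected configuration before the first real step."""
        for config in configs:
            self.get(**config)

    def __len__(self) -> int:
        return len(self._kernels)
```

Typical use would be `cache.warmup([...])` at model load, then `cache.get(head_dim=..., num_query_heads=..., ...)` on each call, so no decode step ever pays the compile cost.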

Notably absent from the call signature: block_table, paged KV, inv_scale, FP8 dequant. All of these are handled upstream by the indexer + gather kernel chain, so FMHA sees a dense BF16 [T, top_k, head_dim] tile.


Priority 3: NVFP4-1.1 — Fuse FP4 quant into MoE SwiGLU epilogue

Independent of FMHA. Biggest bandwidth win in the codebase. Can run in parallel with Priority 1.

Current:

padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
                                  ↓
                          quantize_activation_nvfp4 (separate kernel)
                                  ↓
                         padded_activated_fp4 → L2 GEMM

Target:

padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM

The SwiGLU + clamp result already lives in registers at tRS_rC.store(acc_vec_bf16) (line 2207 of fused_swiglu.py). That's the slot for amax + FP4 pack.

Per-microblock amax (16 contiguous elements):

  1. shfl_xor butterfly reduction across the 4 threads that hold the 16 elements.
  2. FP8 E4M3 scale = amax / 6 (FP4 e2m1 max).
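  3. Per-element FP4 pack: sign bit << 3 | the 3-bit e2m1 code of |val| / scale (round to the nearest of {0, 0.5, 1, 1.5, 2, 3, 4, 6}, clamped at 6), not a plain integer cast. Two elements → one byte.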
  4. 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
  5. FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.

Subtlety: NVFP4 microblock = 16 elements. Port the same 16-element logic from dsv4/ops/quantize.py. Don't accidentally use the 32-element MXFP4 block.
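The five steps and the 16-element rule can be checked on the host before touching the epilogue. A scalar Python model of one NVFP4 microblock, under stated assumptions: each of the 4 lanes holds 4 contiguous elements, element i goes in nibble i (low nibble first within each byte), the block scale stays in float instead of being rounded to E4M3, no per-tensor scale is modeled, and ties round toward the smaller magnitude rather than matching the hardware conversion. All names are hypothetical and do not reflect dsv4/ops/quantize.py:

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # magnitude for 3-bit codes 0..7
E2M1_MAX = 6.0
MICROBLOCK = 16  # NVFP4 block size (MXFP4 would be 32)
LANES = 4        # threads sharing one microblock (assumed layout)


def butterfly_amax(lane_vals):
    """Model a shfl_xor butterfly: after log2(n) rounds every lane holds the max."""
    vals = list(lane_vals)
    offset = 1
    while offset < len(vals):
        vals = [max(vals[i], vals[i ^ offset]) for i in range(len(vals))]
        offset *= 2
    return vals


def e2m1_code(x):
    """4-bit code: sign in bit 3, nearest e2m1 magnitude index in bits 0-2."""
    sign = 1 if x < 0 else 0
    mag = min(abs(x), E2M1_MAX)  # clamp to the largest representable magnitude
    idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - mag))
    return (sign << 3) | idx


def quantize_microblock(block):
    """Return (64-bit packed word, block scale) for 16 values."""
    assert len(block) == MICROBLOCK
    per_lane = MICROBLOCK // LANES
    lane_amax = [max(abs(v) for v in block[i * per_lane:(i + 1) * per_lane])
                 for i in range(LANES)]
    amax = butterfly_amax(lane_amax)[0]
    scale = amax / E2M1_MAX if amax > 0 else 1.0
    word = 0
    for i, v in enumerate(block):
        word |= e2m1_code(v / scale) << (4 * i)  # two elements per byte
    return word, scale
```

Scaling a block by a power of two changes only the scale, not the packed codes, which makes this a handy sanity check for an epilogue port.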

Done when:

  • padded_activated_fp4 and padded_activated_x_sf scratch buffers go away.
  • quantize_activation_nvfp4 between L1 and L2 disappears.
  • L1 → L2 cosine matches reference (no regression from BF16 intermediate).
  • L2 GEMM reads FP4 scales produced by L1 epilogue.

Priority 4: D2 multi-CTA grid

Currently per-head Python launch (works, cos 0.999995, but 128 launches per decode step at Pro).

Multi-CTA grid is unblocked by Priority 1 — the flat_divide + tma_partition path becomes available once the epilogue uses the MoE pattern.

Grid: (num_M_tiles, num_query_heads, batch) — at decode T=1: (1, 128, batch).

MQA K/V sharing: start with independent K/V loads per CTA (each CTA loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB, well within HBM bandwidth. Cluster-wide sharing via cluster_shape_mn=(1, num_query_heads, 1) is a future optimization once profiling shows it matters.
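The grid shape and the bandwidth estimate above are simple arithmetic. A sketch, assuming a 128-row M tile and reusing the ~128 KB per-CTA K/V figure from the text; the helper names are hypothetical:

```python
import math

TILE_M = 128  # assumed query rows per CTA tile


def fmha_grid(seq_len, num_query_heads, batch, tile_m=TILE_M):
    """Launch grid (num_M_tiles, num_query_heads, batch)."""
    return (math.ceil(seq_len / tile_m), num_query_heads, batch)


def independent_kv_bytes(grid, kv_bytes_per_cta):
    """Total K/V bytes read when every CTA loads its own copy (no cluster sharing)."""
    num_ctas = grid[0] * grid[1] * grid[2]
    return num_ctas * kv_bytes_per_cta
```

At decode, `fmha_grid(1, 128, 1)` is `(1, 128, 1)`, and with 128 KiB per CTA the independent-load traffic is 16 MiB per batch element, scaling linearly with batch.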

Q tensor layout: Option 1 — (batch, n_h, T, head_dim) with head as a TMA mode (matches CUTLASS reference and allows per-head LSE output). Picked over Option 2 (heads packed into M) because it generalizes better to GQA later.

Done when:

  • n_h=128, batch=4, T=1 at hd=512 produces correct output with single launch.
  • Per-head LSE writes to correct mLSE[batch, head, m_row] position.

Priority 5: NVFP4-1.2 — Fuse FP4 quant into FMHA output → wo_a path

Depends on Priority 1 (correction epilog gives the register slot).

Currently: FMHA emits BF16 → inverse RoPE produces BF16 → wo_a quantizes to FP4.

Target: the register slot in the FMHA epilogue performs the divide-by-row_sum, the inverse RoPE rotation, and the per-microblock amax + FP4 pack. wo_a reads FP4 directly.

Same pattern as Priority 3. Different home (FMHA epilogue, not MoE epilogue).


Priority 6: NVFP4-2 — FP4 KV pipeline depth in FMHA

Depends on Priority 1 being solid at BF16 KV first.

FP4 KV shrinks tiles ~4×; same SMEM budget supports more pipeline stages.

| KV dtype | Tile size (hd=512) | Stages that fit (192 KB budget) |
|---|---|---|
| BF16 | 128 KB | 2 |
| FP8 | 64 KB | 4 |
| FP4 | ~36 KB | 6 |

At 1M-context decode where KV reads dominate, deeper pipelines hide more TMA latency.

Implementation:

  • TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0.
  • TMA loads BF16 RoPE dims to SMEM slot 1.
  • TMA loads FP8 scale factors to SMEM slot 2.
  • SMEM dequant FP4 → BF16 in vectorized form (* FP8_scale, 16-element microblocks).
  • Concatenate [NoPE, RoPE] in SMEM.
  • MMA reads contiguous BF16 from SMEM.
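The dequant step is the inverse of the NVFP4 pack: decode each nibble to an e2m1 value, multiply by its microblock's scale, then append the RoPE dims. A host-side Python model, assuming low-nibble-first packing, one scale per 16 NoPE elements, and Python floats standing in for BF16 (so BF16 rounding of the products is not modeled); `unpack_e2m1x2` and `dequant_kv_row` are hypothetical names:

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
MICROBLOCK = 16


def unpack_e2m1x2(packed):
    """Decode packed e2m1_x2 bytes (low nibble first) into floats."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            mag = E2M1_GRID[nibble & 0x7]
            out.append(-mag if nibble & 0x8 else mag)
    return out


def dequant_kv_row(nope_packed, nope_scales, rope):
    """Dequantize FP4 NoPE dims and concatenate [NoPE, RoPE] for one KV row."""
    nope = unpack_e2m1x2(nope_packed)
    assert len(nope) == MICROBLOCK * len(nope_scales)
    nope = [v * nope_scales[i // MICROBLOCK] for i, v in enumerate(nope)]
    return nope + list(rope)
```

A row built this way is what the MMA should see, so comparing it against the pure-BF16 row is one way to localize a failure of the transparency test.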

Test: FP4+BF16 split input → identical output to pure BF16 input (dequant must be transparent).


Priority 7: hd=512 fix

Blocked. Per Priority 4, multi-CTA grid + head-packed M means decode at hd=512 can route through pv_n_tile=128 and n_k_sub_tiles=2, which compiles fine for hd=256. The hd=512 single-kernel compile is the missing piece for prefill efficiency, not correctness.

Options:

  1. Pre-compile hd=512 cubin offline (accept 12 hour compile if MLIR ever finishes — uncertain).
  2. Add no-softmax mode emitting raw S to GMEM, call twice for k_sub=0/1, accumulate in Python, softmax once. Two launches but no MLIR hang.
  3. Write hd=512 path in raw CUTLASS C++. Bypasses CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
  4. Report CuTeDSL MLIR optimizer bug to NVIDIA.
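Option 2 works because the QK^T dot product is additive over head_dim: summing the raw scores from two half-width launches reproduces the full score before any softmax. A pure-Python sketch with small shapes, assuming the no-softmax mode returns unscaled raw S; `raw_scores`, `split_head_dim_scores`, and `softmax_rows` are hypothetical helpers:

```python
import math


def raw_scores(q, k, lo, hi):
    """Raw S = Q[:, lo:hi] @ K[:, lo:hi]^T, as a no-softmax launch would emit."""
    return [[sum(qi[d] * kj[d] for d in range(lo, hi)) for kj in k] for qi in q]


def split_head_dim_scores(q, k, n_sub=2):
    """Accumulate raw S over n_sub head_dim chunks (k_sub = 0, 1, ...)."""
    sub = len(q[0]) // n_sub
    total = [[0.0] * len(k) for _ in q]
    for k_sub in range(n_sub):
        part = raw_scores(q, k, k_sub * sub, (k_sub + 1) * sub)
        total = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(total, part)]
    return total


def softmax_rows(s, scale):
    """Softmax applied once, after accumulation."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp((x - m) * scale) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out
```

The sketch covers only the scores; the P·V product over the full head_dim is not shown.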

Lower priority than the chain above — at decode T=1, n_h=128, hd=512 the head-packed approach already works without needing a single hd=512 kernel.


Priority 8: Indexer FP4 tensor-core scoring (Stage F)

Paper §5.2.1: "the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."

Current indexer (dsv4/kernels/cuda/indexer_score_topk.cu): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. Single largest perf gap in the codebase. At 1M-context decode the indexer scores ~250K compressed entries per query token — the spinlock heap will not scale to top_k=1024.

Target: port DeepGEMM fp8_paged_mqa_logits to FP4 inputs with tcgen05.mma.kind=mxf4nvf4, add per-warp partial top-k merged with a final reduction tree (or radix-select), and add FP32→BF16 score quantization per the paper (2× speedup on the top-k selector, 99.7% recall).
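Per-warp partial top-k is exact: any global top-k entry is also in the top-k of whichever slice contains it, so merging the partial lists loses nothing. A host-side Python sketch, assuming strided slices per warp, a flat merge standing in for the reduction tree, and round-to-nearest-even FP32→BF16 (NaN handling ignored); `to_bf16` and `partial_topk` are hypothetical names:

```python
import heapq
import struct


def to_bf16(x):
    """Round to BF16 (round-to-nearest-even on the top 16 bits of the FP32 pattern)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def partial_topk(scores, k, num_warps):
    """Each warp keeps the top-k of its strided slice; one merge picks the global top-k."""
    partials = []
    for w in range(num_warps):
        mine = ((scores[i], i) for i in range(w, len(scores), num_warps))
        partials.append(heapq.nlargest(k, mine))
    merged = heapq.nlargest(k, (item for part in partials for item in part))
    return [i for _, i in merged]
```

Carrying the index alongside the score makes ties deterministic, which matters once BF16 rounding collapses nearby FP32 scores onto the same value.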

Scope: 2–3 weeks. Track for Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 5 are established; they'll inform the indexer's FP4 load + score paths.


Now ─┬─ Priority 1 (correction epilog rewrite)
     │     │
     │     └─→ unblocks D1.5, D2 multi-CTA, NVFP4-1.2
     │
     ├─ Priority 3 (NVFP4-1.1 fuse FP4 in SwiGLU) ← parallel, independent
     │
     ↓
     Verify hd=64/128/256 regressions hold
     │
     ↓
     Priority 2 (Stage E production extraction)
     │
     ↓
     Priority 4 (D2 multi-CTA grid)
     │
     ↓
     Priority 5 (NVFP4-1.2 fuse FP4 in FMHA output)
     │
     ↓
     Priority 6 (NVFP4-2 FP4 KV pipeline)
     │
     ↓
     Priority 7 (hd=512 fix — only if prefill efficiency demands it)
     │
     ↓
     Priority 8 (indexer FP4 tensor-core scoring) — Stage F

Priority 3 has no dependency on Priorities 1 or 2 and can run on a parallel branch.


Speculative — beyond what the V4 paper validated

Listed for completeness. Do not implement without explicit sign-off.

  1. NVFP4 compressed KV NOPE dims (paper validated FP8 for compressed KV; FP4 would halve cache again). Risk: compounds quantization noise on already-lossy compressed KV.
  2. MXFP4 vs NVFP4 for indexer scoring — not validated for indexer specifically.
  3. NVFP4 for full attention Q×K^T GEMM — closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
  4. Per-token FP8 activation scaling in FMHA — not validated. Out of scope.
  5. 2:4 structured sparsity on FP4 expert weights — V4 not trained with structured sparsity. Off the table for the released checkpoint.
  6. NVFP4 LM head + MTP head — big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against held-out eval before shipping.

Key numbers to remember

| Config | n_h | top_k | s_k (decode) | n_kv_tiles | Multi-tile? |
|---|---|---|---|---|---|
| Flash decode | 64 | 512 | 640 | 5 | YES |
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
| Current single-tile test | 1 | | 128 | 1 | NO |

Production decode needs the multi-tile path (Priority 1) working in-kernel. Today's Python KV merge ships correct results at the cost of 5–9 launches per step.
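The n_kv_tiles column follows from s_k and the 128-key tile. In both decode rows s_k = top_k + 128, which suggests a 128-key sliding window on top of the compressed top-k entries; that reading is an inference from the table, not a stated spec. A sketch with hypothetical names:

```python
import math

KV_TILE = 128  # keys per KV tile (the single-tile path, n=128)


def decode_kv_len(top_k, swa_window=128):
    """s_k at decode: top-k compressed entries plus the SWA keys (window inferred)."""
    return top_k + swa_window


def n_kv_tiles(s_k, tile=KV_TILE):
    """Number of KV tiles the kernel (or the Python merge) must cover."""
    return math.ceil(s_k / tile)
```

This gives 5 tiles for Flash (top_k=512) and 9 for Pro (top_k=1024), matching the table.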