nvfp4-megamoe-kernel/ROADMAP.md
biondizzle bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00

# ROADMAP
Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md — this file is for "what now."
**Last updated:** 2026-05-26 (revised after correcting D1.5 fix-path analysis)
---
## Current status
### Working
| Component | hd | n | cos | Status |
|---|---:|---:|---:|---|
| FMHA TMEM-P | 64 | 128 | 0.999998 | ✅ |
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | ✅ |
| FMHA TMEM-P | 256 | 128 | 0.999998 | ✅ |
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | ✅ Workaround — see below |
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | ✅ |
| D5c sink merge single-tile | 64 | 128 | 0.999996 | ✅ |
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | ✅ |
| Per-head multi-head launch | 64 | 128 | 0.999995 | ✅ n_h=1–128 |
| MoE fused SwiGLU (NVFP4) | — | — | matches ref | ✅ Clamping in kernel |
| Dense router (sqrt-softplus) | — | — | matches ref | ✅ |
| Hash router | — | — | matches ref | ✅ |
| `use_2cta_instrs` conditional | — | — | 1.7–1.9× speedup | ✅ M≥256 prefill |
| NVFP4 primitives | — | — | E4M3 SF / mxf4nvf4 / 16-elem | ✅ Verified |
### Known blockers
| Blocker | Impact | Status |
|---|---|---|
| **Per-kt O rescale in TMEM (D1.5)** | Multi-tile attention requires 5–9 kernel launches per decode step (Python KV merge) instead of 1 | Workaround is correct (cos 0.999998). Whether to fix depends on profiling; see Priority 1. |
| **TMEM final-normalize round-trip** | Cannot do in-kernel `O /= row_sum` cleanly | **Already worked around** — emit un-normalized O + LSE, external divide is exact. Not a blocker for shipping. |
| **`epilogue_tma_store` blocks D2 multi-CTA + NVFP4-1.2** | Per-head Python launch wastes 128 launches per Pro decode step; FMHA output forces BF16 GMEM round-trip before wo_a | Unblocked by Priority 2 (the one-way final-epilogue rewrite — which is the **only** part of the MoE pattern that legitimately ports to FMHA). |
| **hd=512 MLIR backend hang** | Cannot compile single-kernel hd=512 (>3hr optimizer time, structurally correct) | Decode works via head-packed M with hd≤256 chunks. Single-kernel hd=512 only needed for prefill efficiency. Low priority. |
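The Python KV merge behind the multi-tile workaround is essentially a log-sum-exp combine of per-tile results. Below is a minimal pure-Python illustration for one query row. It assumes each per-tile launch returns a normalized output `o_i` and its LSE. If the kernel emits un-normalized O, divide by the row sum first. `attend_tile` and `merge_tiles` are hypothetical names, not the repository's merge code, and real tensors would be batched on device.

```python
import math
from typing import List, Sequence, Tuple


def attend_tile(q: Sequence[float], keys: Sequence[Sequence[float]],
                values: Sequence[Sequence[float]], scale: float) -> Tuple[List[float], float]:
    """Single-tile attention for one query row: returns (normalized O, LSE)."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    l = sum(p)
    hd = len(values[0])
    o = [sum(pj * v[d] for pj, v in zip(p, values)) / l for d in range(hd)]
    return o, m + math.log(l)


def merge_tiles(partials: Sequence[Tuple[List[float], float]]) -> Tuple[List[float], float]:
    """Combine per-tile (O_i, LSE_i) into the full-sequence (O, LSE)."""
    lse_max = max(lse for _, lse in partials)          # subtract max for stability
    weights = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(weights)
    lse = lse_max + math.log(total)
    hd = len(partials[0][0])
    # Each tile is weighted by exp(LSE_i - LSE), its share of the full softmax mass.
    o = [sum(w * oi[d] for w, (oi, _) in zip(weights, partials)) / total for d in range(hd)]
    return o, lse
```

Split the keys into tiles and run `attend_tile` on each. Feeding the results to `merge_tiles` then reproduces a single call over all keys, up to floating-point rounding. That is why the workaround can match single-launch cosine.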
---
## What the MoE epilogue pattern actually buys us (and what it doesn't)
This deserves stating explicitly because the previous version of this document had it wrong.
**The MoE pattern is one-way only.** `dsv4/kernels/gemm/fused_swiglu.py` uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` to construct a TMEM → registers → SMEM → GMEM pipeline. There is **no** corresponding store-back-to-TMEM helper, and no inverse pairing of the t2r atom. The MoE epilogue runs *once*, *after* all MMA K-tiles are accumulated. It never needs to mutate the TMEM accumulator mid-loop.
**FMHA's per-kt O rescale is structurally different.** PV uses `tcgen05.mma` with `ACCUMULATE=True` across the kt loop. The accumulator must live in TMEM because that's where MMA reads it and writes it. When row_max changes between kt iterations, the running O accumulator has to be multiplied by `acc_scale` *in TMEM* before the next PV — load to registers, multiply, store back. The store-back is the part that's broken: `Ld32x32bOp` and `St32x32bOp` built as separate atoms have hardware column mappings that don't match, producing ~3% corruption even on NO-OP round-trip. No software layout transformation in CuTeDSL Python has, so far, made them pair correctly.
**Therefore:** porting the MoE pattern to FMHA **only fixes the one-way paths**:
1. The final epilogue (after the last kt, when O is being written to GMEM for good).
2. Any FP4 amax + pack fusion into that final epilogue.
It does **not** fix the per-kt rescale. That is a separate, harder problem with three possible paths laid out below.
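For reference, the per-kt update that must happen in TMEM is the standard online-softmax rescale. Below is a minimal pure-Python sketch for one query row, assuming FP32-like floats and a hypothetical `kt_size`. The running O stays un-normalized. Whenever `row_max` grows, O is multiplied by `acc_scale = exp(old_max - new_max)` before the next PV accumulates into it. The list comprehension on `o` stands in for the TMEM load, multiply and store-back round-trip.

```python
import math
from typing import List, Sequence, Tuple


def online_attention(q: Sequence[float], keys: Sequence[Sequence[float]],
                     values: Sequence[Sequence[float]], scale: float,
                     kt_size: int) -> Tuple[List[float], float, float]:
    """Flash-style kt loop for one query row. Returns (un-normalized O, row_max, row_sum)."""
    hd = len(values[0])
    o = [0.0] * hd            # running accumulator (TMEM-resident in the kernel)
    row_max = -math.inf
    row_sum = 0.0
    for start in range(0, len(keys), kt_size):
        k_tile = keys[start:start + kt_size]
        v_tile = values[start:start + kt_size]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        new_max = max(row_max, max(s))
        acc_scale = math.exp(row_max - new_max)   # exp(-inf) == 0.0 on the first tile
        o = [x * acc_scale for x in o]            # the in-TMEM load -> multiply -> store
        row_sum *= acc_scale
        p = [math.exp(x - new_max) for x in s]
        row_sum += sum(p)
        # PV with ACCUMULATE=True adds this tile's contribution onto the rescaled O.
        for pj, v in zip(p, v_tile):
            for d in range(hd):
                o[d] += pj * v[d]
        row_max = new_max
    return o, row_max, row_sum
```

Dividing the returned O by `row_sum` gives the full softmax-weighted output. Omitting the rescale leaves earlier tiles over-weighted relative to later ones, which is why the store-back cannot be skipped.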
---
## Priority 1: Profile production decode to determine if the rescale fix is needed
Before investing days in fixing the per-kt rescale, measure whether the overhead of the 5–9 launches that Python KV merge adds is actually a bottleneck.
At Pro decode (s_k=1152, n_kv_tiles=9), Python merge dispatches 9 kernels per step. At a conservative ~50 μs of launch overhead per kernel, that is ≈ 450 μs/step in launch overhead alone. If a full decode step (all 61 layers, MoE, embedding, sampler) takes ~30 ms, that's ~1.5% of latency; if it takes ~10 ms, it's ~4.5%. These figures treat the 9 launches as the whole step's cost. If every attention layer runs its own merge, the total multiplies by the layer count, and the profile should confirm which case applies. Whether that's worth a 1–2 week refactor depends on the actual measurement.
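The helper below reproduces the launch-overhead arithmetic and applies the 5% decision rule from the Action list. It is a sketch. The 50 μs per launch is the conservative assumption from the paragraph, not a measurement. `layers` is a hypothetical multiplier for the case where each attention layer pays the merge launches separately; the default of 1 reproduces the figures above.

```python
def launch_overhead_fraction(n_launches: int, per_launch_us: float,
                             step_ms: float, layers: int = 1) -> float:
    """Fraction of one decode step spent on kernel-launch overhead."""
    overhead_us = n_launches * per_launch_us * layers
    return overhead_us / (step_ms * 1000.0)


def rescale_fix_justified(overhead_fraction: float, threshold: float = 0.05) -> bool:
    """Priority 8 is only justified when merge overhead exceeds the threshold."""
    return overhead_fraction > threshold


for step_ms in (30.0, 10.0):
    frac = launch_overhead_fraction(n_launches=9, per_launch_us=50.0, step_ms=step_ms)
    print(f"{step_ms:.0f} ms step: {frac:.1%} launch overhead, "
          f"fix justified: {rescale_fix_justified(frac)}")
```

The measured launch cost and step latency from the profile would replace the assumed constants.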
**Action:**
- [ ] Profile Pro decode at s_k=1152 with current Python KV merge. Measure: total step latency, launch overhead from FMHA dispatches, FMHA compute time per launch.
- [ ] Measure CPU dispatch overhead on the host (Python loop + kernel cache lookup).
- [ ] Decision rule: if Python merge overhead is < 5% of total decode latency, defer Priority 8 indefinitely. Ship Python merge as production path.
**Done when:** there's a profiled number that justifies (or doesn't) the engineering investment in Priority 8.
---
## Priority 2: One-way final-epilogue rewrite
**What:** Replace the `utils.gemm.sm100.epilogue_tma_store(...)` call at the end of FMHA (`dsv4/kernels/attention/fmha.py` lines 565–577) with the MoE-style explicit pipeline:
```
transform_partitioned_tensor_layout → epilogue_tmem_copy_and_partition →
[register slot — optional normalize/cast/FP4 pack] →
epilogue_smem_copy_and_partition → flat_divide → cpasync.tma_partition → TMA store
```
This is **strictly one-way**. The kt-loop rescale code stays exactly as it is (using the broken hand-built atoms — see Priority 8 for the fix path).
**What this enables:**
1. **Optional in-kernel normalize.** Adds an `if const_expr(self.normalize):` block at the register slot to multiply by `inv_row_sum`. Currently external code does the divide on the un-normalized output. Tiny perf win, not the main reason to do this.
2. **Unblocks NVFP4-1.2** (Priority 6) — gives a register-level modification slot in the FMHA output path where FP4 amax + pack can live, eliminating the BF16 GMEM materialization between FMHA and wo_a.
3. **Likely unblocks D2 multi-CTA grid** (Priority 4) — the current `epilogue_tma_store` is what couldn't accept the `flat_divide`-based GMEM coordinate system. Switching to the explicit `cpasync.tma_partition(tma_c, ..., cute.flat_divide(tCgC_transformed, epi_tile))` path puts FMHA on the same TMA pattern MoE uses successfully, which should accept multi-CTA block coordinates.
**Caveats to verify on B200 before assuming this works:**
- Whether `tma_partition` survives inside the `if warp_idx < self.mma_warp_id` block. MoE calls it from inside its epilogue warp's `if`, but that has not been tested in FMHA's region tree. Previous attempts at the full pattern triggered 20+ minute MLIR compile times before reaching a verdict.
- Whether `transform_partitioned_tensor_layout` accepts FMHA's `tOtO0` (TMEM iterator with offset) directly, or whether it needs a fresh tensor built at `tmem_ptr + self.tmem_o0_offset` with `tCtO_fake.layout`.
- The `epilogue_tmem_copy_and_partition` helper signature on the current CuTeDSL version — print on B200 before coding.
**Failure mode to watch for:** if compile hangs as it did previously, this rewrite is genuinely blocked and the chain of follow-ons (Priorities 4 and 6) need alternative paths.
**Effort:** 1–2 days if the helpers cooperate; several more days if the MLIR hang reappears.
**Done when:**
- hd=64/128/256 regression cos holds at the current baselines (≥ 0.999998 for hd=64/256, ≥ 0.999997 for hd=128) with both `normalize=True` and `normalize=False` paths.
- LSE output still matches reference for `normalize=False` callers.
- Compile time is reasonable (< 5 min) — if not, document the hang and fall back.
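Below is a host-side sketch of the two numeric checks, assuming outputs are flattened to float lists per query row. `cosine` is the metric used in the regression table. `check_normalize_false_path` is a hypothetical helper with two jobs: it divides the `normalize=False` output by `row_sum` externally, and it recomputes LSE as `row_max + log(row_sum)`, the relationship the un-normalized path relies on. Thresholds are parameters because the pass bar is a project decision.

```python
import math
from typing import Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two flattened outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def check_normalize_false_path(o_unnorm: Sequence[float], row_max: float, row_sum: float,
                               ref_o: Sequence[float], ref_lse: float,
                               cos_min: float, lse_tol: float = 1e-3) -> bool:
    """External divide plus LSE reconstruction must both match the reference."""
    o_external = [x / row_sum for x in o_unnorm]
    lse = row_max + math.log(row_sum)
    return cosine(o_external, ref_o) >= cos_min and abs(lse - ref_lse) <= lse_tol
```

For the `normalize=True` path, `cosine(o_kernel, ref_o) >= cos_min` alone covers the first check.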
---
## Priority 3: NVFP4-1.1 — Fuse FP4 quant into MoE SwiGLU epilogue
**Independent of FMHA. Can run in parallel with Priority 2.** Biggest bandwidth win in the codebase.
Current:
```
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
quantize_activation_nvfp4 (separate kernel)
padded_activated_fp4 → L2 GEMM
```
Target:
```
padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM
```
The SwiGLU + clamp result already lives in registers at `tRS_rC.store(acc_vec_bf16)` (line 2207 of `fused_swiglu.py`). That's the slot for amax + FP4 pack.
**Per-microblock amax (16 contiguous elements):**
1. `shfl_xor` butterfly reduction across the 4 threads holding the 16 elements.
2. FP8 E4M3 scale = amax / 6 (FP4 e2m1 max).
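3. Per-element FP4 pack: `sign_bit << 3 | e2m1_code(|clamped_val| / scale)`, where the 3-bit magnitude code is the index of the nearest E2M1 value in {0, 0.5, 1, 1.5, 2, 3, 4, 6}. A plain integer cast does not work because E2M1 codes are not integers: 1.5 would truncate to 1, and code 0b001 decodes as 0.5. Two elements → one byte.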
4. 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
5. FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.
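Steps 1 to 4 can be emulated on the host to produce golden values for the fused epilogue. Below is a minimal pure-Python sketch. It makes several assumptions, which are also labeled in the code comments:

- The 16 elements sit 4 per thread across 4 threads.
- The scale stays in FP32. The kernel would round it to E4M3, and NVFP4 also carries a per-tensor scale; both are omitted here.
- Ties round toward the lower code instead of hardware round-to-nearest-even.
- Element 0 lands in the low nibble.

`butterfly_amax`, `e2m1_code` and `quantize_microblock` are hypothetical names.

```python
from typing import List, Sequence, Tuple

# FP4 E2M1 magnitudes, indexed by their 3-bit code (sign is the 4th bit).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def butterfly_amax(lanes: Sequence[Sequence[float]]) -> List[float]:
    """Emulate a shfl_xor butterfly max across 4 lanes (xor offsets 1, then 2)."""
    vals = [max(abs(x) for x in lane) for lane in lanes]  # per-thread partial amax
    for offset in (1, 2):
        vals = [max(vals[i], vals[i ^ offset]) for i in range(len(vals))]
    return vals  # every lane now holds the microblock amax


def e2m1_code(x: float) -> int:
    """sign << 3 | index of the nearest E2M1 magnitude (ties go to the lower code)."""
    sign = 1 if x < 0 else 0
    mag = min(abs(x), 6.0)  # saturate at the E2M1 maximum
    idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - mag))
    return (sign << 3) | idx


def quantize_microblock(block: Sequence[float]) -> Tuple[float, int]:
    """16 values -> (FP32 scale, 64-bit word of 16 nibbles, element 0 in the low nibble)."""
    assert len(block) == 16
    lanes = [block[i:i + 4] for i in range(0, 16, 4)]  # assumed 4 elements per thread
    amax = butterfly_amax(lanes)[0]
    scale = amax / 6.0 if amax > 0 else 1.0  # the kernel would round this to E4M3
    word = 0
    for i, x in enumerate(block):
        word |= e2m1_code(x / scale) << (4 * i)
    return scale, word
```

With this packing, byte j of the 64-bit word holds elements 2j (low nibble) and 2j+1 (high nibble), matching "two elements → one byte".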
**Done when:** `padded_activated_fp4` and `padded_activated_x_sf` scratch buffers go away, `quantize_activation_nvfp4` between L1 and L2 disappears, L1→L2 cosine matches reference.
---
## Priority 4: D2 multi-CTA grid
**Depends on Priority 2.** Currently per-head Python launch dispatches 128 kernels per Pro decode step. Multi-CTA grid collapses that to 1.
**Grid:** `(num_M_tiles, num_query_heads, batch)` — at decode T=1: `(1, 128, batch)`.
**Q tensor layout:** Option 1 — `(batch, n_h, T, head_dim)` with head as a TMA mode. Matches CUTLASS reference, allows per-head LSE output, generalizes to GQA later.
**MQA K/V sharing:** start with independent K/V loads per CTA (each loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB total, comfortably within HBM bandwidth. Cluster-wide sharing via `cluster_shape_mn=(1, num_query_heads, 1)` is a future optimization once profiling shows it matters.
**Done when:** `n_h=128, batch=4, T=1` at hd=512 produces correct output with single launch, per-head LSE writes to `mLSE[batch, head, m_row]` correctly.
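The mapping from the proposed grid to work can be sketched on the host. The sketch assumes 128-row query tiles (`tile_m`), matching the 128-row tiles used elsewhere in this document, and uses hypothetical helper names. It only illustrates the indexing that the "Done when" criterion checks, not the kernel's launch code.

```python
import math
from typing import Iterator, Tuple


def fmha_grid(seq_len_q: int, num_query_heads: int, batch: int,
              tile_m: int = 128) -> Tuple[int, int, int]:
    """(num_M_tiles, num_query_heads, batch): one CTA per (query tile, head, batch)."""
    return (math.ceil(seq_len_q / tile_m), num_query_heads, batch)


def lse_rows(block_idx: Tuple[int, int, int], seq_len_q: int,
             tile_m: int = 128) -> Iterator[Tuple[int, int, int]]:
    """Yield the mLSE[batch, head, m_row] coordinates one CTA is responsible for."""
    m_tile, head, b = block_idx
    start = m_tile * tile_m
    for m_row in range(start, min(start + tile_m, seq_len_q)):
        yield (b, head, m_row)


print(fmha_grid(seq_len_q=1, num_query_heads=128, batch=4))
```

At decode T=1, each CTA owns exactly one LSE row.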
---
## Priority 5: Stage E — Production extraction
D5 is complete. Wrap the kernel in a proper interface.
| Step | What | Status |
|---|---|---|
| E1 | File placement: `dsv4/kernels/attention/fmha.py` | ✅ Done |
| E2 | Constructor signature (`head_dim`, `num_query_heads`, `sliding_window`, `top_k`, sink/causal flags, dtypes) | ⚠️ Partial — needs cleanup |
| E3 | Call signature: `q`, `compressed_kv`, `swa_kv`, `swa_lens`, `sink_logits`, `request_ids`, `o`, `stream` | ⚠️ Needs sink_bias / row_sums integration |
| E4 | Kernel cache + warmup, keyed on `(head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...)` | TODO |
| E5 | `torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))` | TODO |
| E6 | Reference parity test against FP32 oracle in `dsv4/reference/attention.py` | TODO |
| E7 | Cleanup: delete debug test files, keep only `tests/unit/test_fmha_kernel.py` | TODO |
Block table, paged KV, FP8 dequant, inv_scale — all handled upstream by the indexer + gather chain. FMHA sees a dense BF16 `[T, top_k, head_dim]` tile.
---
## Priority 6: NVFP4-1.2 — Fuse FP4 quant into FMHA output → wo_a path
**Depends on Priority 2** (uses the register slot in the new final epilogue).
Currently: FMHA emits BF16 → inverse RoPE → BF16 GMEM → wo_a quantizes to FP4.
Target: register slot in FMHA's new final epilogue does `O / row_sum` *and* inverse RoPE rotation *and* per-microblock amax *and* FP4 pack. wo_a reads FP4 directly with no GMEM materialization.
Same pattern as Priority 3, different home (FMHA final epilogue, not MoE epilogue).
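The first two register-slot operations, normalize and inverse RoPE, can be prototyped as a host reference. This sketch makes three assumptions:

- RoPE dims are the trailing `rope_dims` lanes, consistent with the `[NoPE, RoPE]` layout in Priority 7.
- Rotation pairs are interleaved `(x[2i], x[2i+1])`.
- The frequency base is 10000.

The actual DSV4 RoPE layout and base may differ, so verify them before treating this as golden. The amax and FP4 pack would then run on the returned values as in Priority 3.

```python
import math
from typing import List, Sequence


def rope_rotate(x: Sequence[float], pos: int, base: float = 10000.0,
                sign: float = 1.0) -> List[float]:
    """Rotate interleaved pairs (x[2i], x[2i+1]) by sign * pos * base**(-2i/d)."""
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = sign * pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def epilogue_register_slot(o_unnorm: Sequence[float], row_sum: float,
                           pos: int, rope_dims: int) -> List[float]:
    """O / row_sum, then inverse RoPE (rotation by -theta) on the trailing rope_dims lanes."""
    o = [x / row_sum for x in o_unnorm]
    split = len(o) - rope_dims           # assumed [NoPE, RoPE] ordering
    return o[:split] + rope_rotate(o[split:], pos, sign=-1.0)
```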
---
## Priority 7: NVFP4-2 — FP4 KV pipeline depth in FMHA
**Depends on Priority 2 being solid at BF16 KV first.** FP4 KV shrinks tiles ~4×; same SMEM budget supports more pipeline stages.
| KV dtype | Tile size (hd=512) | Stages fitting 192 KB |
|---|---:|---:|
| BF16 | 128 KB | 2 |
| FP8 | 64 KB | 4 |
| FP4 | ~36 KB | 6 |
At 1M-context decode where KV reads dominate, deeper pipelines hide more TMA latency.
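The tile-size column can be reproduced as 128 KV rows × 512 dims × element width. For FP4, the ~36 KB figure is consistent with 32 KB of packed E2M1 data plus one 1-byte E4M3 scale per 16-element microblock (4 KB). This sketch assumes all 512 dims are FP4, ignoring the BF16 RoPE split below. `kv_tile_bytes` is a hypothetical helper, and the sketch does not attempt the stage count.

```python
def kv_tile_bytes(rows: int, head_dim: int, bits_per_elem: int,
                  sf_block: int = 0, sf_bytes: int = 1) -> int:
    """Bytes for one KV tile, plus one scale factor per `sf_block` elements when sf_block > 0."""
    data = rows * head_dim * bits_per_elem // 8
    scales = (rows * head_dim // sf_block) * sf_bytes if sf_block else 0
    return data + scales


for name, bits, sf in (("BF16", 16, 0), ("FP8", 8, 0), ("FP4", 4, 16)):
    print(name, kv_tile_bytes(128, 512, bits, sf_block=sf) // 1024, "KiB")
```

This prints 128, 64 and 36 KiB, matching the table.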
**Implementation:**
- TMA loads FP4 NoPE dims (`e2m1_x2` packed) to SMEM slot 0.
- TMA loads BF16 RoPE dims to SMEM slot 1.
- TMA loads FP8 scale factors to SMEM slot 2.
- SMEM dequant FP4 → BF16 vectorized (`* FP8_scale`, 16-element microblocks).
- Concatenate `[NoPE, RoPE]` in SMEM.
- MMA reads contiguous BF16 from SMEM.
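The dequant and concatenate steps can be mirrored on the host to build the "identical output" test. Below is a minimal sketch with these assumptions:

- Each byte holds two E2M1 codes, low nibble first.
- There is one already-decoded FP8 scale per 16 NoPE elements.
- Python floats stand in for BF16; no BF16 rounding is emulated.

`decode_e2m1`, `dequant_nope` and `assemble_kv_row` are hypothetical names.

```python
from typing import List, Sequence

# FP4 E2M1 magnitudes, indexed by the low 3 bits of a code.
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """4-bit code: bit 3 is the sign, bits 0-2 index the E2M1 magnitude grid."""
    mag = E2M1_GRID[nibble & 0x7]
    return -mag if nibble & 0x8 else mag


def dequant_nope(packed: bytes, scales: Sequence[float], sf_block: int = 16) -> List[float]:
    """Unpack two elements per byte (low nibble first), one scale per microblock."""
    out: List[float] = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            out.append(decode_e2m1(nibble) * scales[len(out) // sf_block])
    return out


def assemble_kv_row(packed_nope: bytes, scales: Sequence[float],
                    rope: Sequence[float]) -> List[float]:
    """Concatenate [NoPE, RoPE] so the MMA sees one contiguous BF16-like row."""
    return dequant_nope(packed_nope, scales) + list(rope)
```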
**Test:** FP4+BF16 split input → identical output to pure BF16 input. Dequant must be transparent.
---
## Priority 8 (conditional on P1 profile): Per-kt O rescale fix
**Only justified if Priority 1 profiling shows Python KV merge overhead > 5% of decode latency.** Otherwise defer indefinitely — the current correct workaround ships.
There are three possible paths. None is a small change.
### Path A: CUTLASS atom replication
The CUTLASS C++ Blackwell FMHA does a TMEM round-trip for O rescale using `SM100_TMEM_LOAD_32dp32b_4x_atom` and `SM100_TMEM_STORE_32dp32b_4x_atom`, which are paired by hardware design. So in principle TMEM round-trip is possible. The question is whether CuTeDSL Python exposes the specific atom variants and layout configuration CUTLASS uses, and whether they can be paired correctly through `make_tmem_copy`.
**Steps:**
- [ ] Read `/root/cutlass/.../blackwell/kernel/attention/fmha/fmha.py` (or equivalent C++ reference) and document the exact atom + repetition + tensor layout used for `correction_rescale`.
- [ ] Enumerate what CuTeDSL Python exposes: `dir(tcgen05.copy)`, available `LdNxNbOp` / `StNxNbOp` variants, what `Repetition(N)` controls.
- [ ] Identify the difference between current `Ld32x32bOp(Repetition(16))` + `St32x32bOp(Repetition(16))` and whatever CUTLASS uses.
- [ ] Build a minimal NO-OP round-trip test (load O, store back unchanged) with the candidate atom configuration. Verify cos = 1.0.
- [ ] If NO-OP passes, retest with `* acc_scale` modification.
**Risk:** CuTeDSL Python may not expose the necessary atom variants, or may not allow the layout configuration CUTLASS uses. In that case, escalate to Path B or C.
**Effort if it works:** 2–4 days investigation + 1–2 days porting.
### Path B: O accumulator in registers, manual PV
Restructure FMHA so the PV accumulator is register-resident, not TMEM-resident. Each kt: read V from SMEM, read P from TMEM/SMEM, compute one-shot PV (no accumulate) writing to a temporary TMEM region, then load to registers and add to register-resident running O (with acc_scale applied to the running O before the add).
**Implications:**
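- Register pressure is severe at hd=512. 512 FP32 per row × 4 B × 1 row per thread × 128 threads = 256 KB, which is the entire register file of an SM (64K 32-bit registers). 512 registers per thread is also over the 255-register per-thread limit. One row per thread is therefore not feasible; O would have to be spread across more threads or chunked along hd.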
- PV without TMEM accumulate is a non-standard MMA usage. May need to use a smaller PV tile and accumulate in registers across sub-tiles.
- Loses the natural MMA-accumulator pipeline overlap.
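The register budget for a register-resident O can be checked with simple arithmetic. This sketch assumes 64K 32-bit registers per SM and the CUDA limit of 255 registers per thread. `o_register_cost` is a hypothetical helper, and it counts only O, ignoring the registers the rest of the kernel needs.

```python
from typing import Dict

REGS_PER_SM = 64 * 1024      # 32-bit registers per SM (assumed for SM100)
MAX_REGS_PER_THREAD = 255    # CUDA per-thread register limit


def o_register_cost(head_dim: int, rows: int, threads: int,
                    bytes_per_elem: int = 4) -> Dict[str, object]:
    """Register cost of holding an O tile (rows x head_dim) spread evenly over `threads`."""
    total_regs = rows * head_dim * bytes_per_elem // 4  # one 32-bit register per 4 bytes
    per_thread = total_regs // threads
    return {
        "total_kib": total_regs * 4 // 1024,
        "regs_per_thread": per_thread,
        "fits_thread_limit": per_thread <= MAX_REGS_PER_THREAD,
        "fraction_of_sm_regs": total_regs / REGS_PER_SM,
    }


print(o_register_cost(head_dim=512, rows=128, threads=128))
print(o_register_cost(head_dim=512, rows=128, threads=512))
```

With hd=512, 128 rows and 128 threads, it reports 256 KiB and 512 registers per thread. Spreading the same tile over 512 threads brings the per-thread count down to 128, but the tile still occupies the whole register file.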
**Effort:** 1–2 weeks. High risk of regressing hd=64/128/256 paths during the refactor.
### Path C: O accumulator in SMEM
Variant of Path B with O in SMEM instead of registers. PV writes to a temporary TMEM region, gets loaded to registers, applies acc_scale * existing_SMEM_O + new_PV, stores back to SMEM. Final epilogue reads from SMEM.
**Implications:**
- SMEM budget tightens significantly (need O in SMEM = ~64 KB at hd=512).
- Adds SMEM read/write pressure on every kt.
- May require dropping kv_stage to 1 across the board.
**Effort:** ~1 week. Lower risk than Path B but bigger perf impact (more SMEM traffic).
### Recommended order if profiling demands a fix
1. Try Path A first — least invasive, may just work with the right atom config.
2. If Path A confirmed impossible in CuTeDSL Python, try Path C (SMEM-resident O) before Path B (register-resident O). SMEM gives more headroom at hd=512.
3. Path B only if both fail and the perf gap is truly critical.
---
## Priority 9: hd=512 single-kernel fix
**Currently blocked.** MLIR optimizer hangs > 3 hours on the hd=512 kernel. Tracer completes in 0.8s — kernel is structurally correct.
Decode works via head-packed M with `pv_n_tile=128` and `n_k_sub_tiles=2`, so this is only a prefill efficiency issue. Lower priority than the chain above.
**Options:**
1. Pre-compile cubin offline (accept 12 hour compile, cache result).
2. Add no-softmax mode emitting raw S to GMEM; call twice for k_sub=0/1, accumulate in Python, softmax once externally.
3. Write hd=512 path in raw CUTLASS C++. Bypasses CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
4. Report CuTeDSL MLIR optimizer bug to NVIDIA.
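Option 2 relies on S = QK^T being additive over head-dim slices: two raw-S launches (k_sub=0/1) summed in Python equal one full-width launch. Below is a pure-Python sketch of that identity. `qk_scores` and `split_hd_scores` are hypothetical stand-ins for the proposed no-softmax mode, and the softmax scale multiply is omitted.

```python
from typing import List, Sequence


def qk_scores(q: Sequence[float], keys: Sequence[Sequence[float]],
              lo: int, hi: int) -> List[float]:
    """Raw S = Q.K^T restricted to head-dim slice [lo, hi) (one no-softmax launch)."""
    return [sum(q[d] * k[d] for d in range(lo, hi)) for k in keys]


def split_hd_scores(q: Sequence[float], keys: Sequence[Sequence[float]],
                    head_dim: int = 512, n_k_sub_tiles: int = 2) -> List[float]:
    """Sum raw S over head-dim sub-tiles; softmax would be applied once afterwards."""
    sub = head_dim // n_k_sub_tiles
    total = [0.0] * len(keys)
    for k_sub in range(n_k_sub_tiles):
        part = qk_scores(q, keys, k_sub * sub, (k_sub + 1) * sub)
        total = [t + p for t, p in zip(total, part)]
    return total
```

The softmax only sees the summed S, so the split is invisible to the rest of the pipeline, apart from floating-point summation order.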
---
## Priority 10: Indexer FP4 tensor-core scoring (Stage F)
Paper §5.2.1: *"the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."*
Current indexer (`dsv4/kernels/cuda/indexer_score_topk.cu`): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. Single largest perf gap in the codebase. At 1M-context decode it scores ~250K compressed entries per query token — the spinlock heap will not scale to top_k=1024.
**Target:** port DeepGEMM `fp8_paged_mqa_logits` to FP4 inputs with `tcgen05.mma.kind=mxf4nvf4`. Plus per-warp partial top-k merged with a final reduction tree (or radix-select). Plus FP32→BF16 score quantization (paper claims 2× speedup on top-k selector at 99.7% recall).
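The per-warp partial top-k followed by a final reduction rests on one property. Every element of the global top-k is also in the top-k of its own chunk, so merging the per-chunk survivors is exact. Below is a pure-Python sketch of that property using `heapq`. Chunks stand in for warps, and the final step is a single flat merge rather than a multi-level tree or radix-select. Ties are broken by index so that partial and final orderings agree. Function names are hypothetical.

```python
import heapq
from typing import List, Sequence, Tuple


def partial_topk(scores: Sequence[float], base: int, k: int) -> List[Tuple[float, int]]:
    """One 'warp': top-k (score, global_index) pairs from its slice of the scores."""
    return heapq.nlargest(k, ((s, base + i) for i, s in enumerate(scores)))


def merged_topk(scores: Sequence[float], k: int, chunk: int) -> List[int]:
    """Per-chunk top-k, then one final reduction over the surviving candidates."""
    partials: List[Tuple[float, int]] = []
    for base in range(0, len(scores), chunk):
        partials.extend(partial_topk(scores[base:base + chunk], base, k))
    return [idx for _, idx in heapq.nlargest(k, partials)]
```

The final merge touches at most k × num_chunks candidates instead of every score, which is where the scaling win over a single shared heap would come from.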
**Scope:** 2–3 weeks. Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 6 are established; they inform the indexer's FP4 load + score paths.
---
## Build order — recommended sequencing
```
Priority 1 (PROFILE) ──► gates Priority 8

Priority 2 (one-way     ─┬─► unblocks Priority 4 (multi-CTA)
  final epilogue)        │   unblocks Priority 6 (FP4 fuse in FMHA)
Priority 3 (NVFP4-1.1)  ─┴── parallel, independent

[verify hd regressions]
Priority 4 (D2 multi-CTA grid)
Priority 5 (Stage E production extraction)
Priority 6 (NVFP4-1.2 FP4 fuse in FMHA output)
Priority 7 (NVFP4-2 FP4 KV pipeline)
Priority 8 (per-kt rescale fix, ONLY if P1 says it matters)
Priority 9 (hd=512 fix, only if prefill efficiency demands)
Priority 10 (indexer FP4 tensor-core scoring, Stage F)
```
**Key change from the previous version:** Priority 1 is now a profiling task, not an engineering task. Priority 2 is scoped honestly — it's the one-way path only, and it does *not* fix the per-kt rescale. The per-kt rescale is Priority 8, conditional, and three paths deep because there is no easy fix.
---
## Speculative — beyond what the V4 paper validated
Listed for completeness. **Do not implement without explicit sign-off.**
1. **NVFP4 compressed KV NoPE dims.** Paper validated FP8; FP4 would halve cache again. Risk: compounds quantization noise on already-lossy compressed KV.
2. **MXFP4 vs NVFP4 for indexer scoring.** Not validated for indexer specifically.
3. **NVFP4 for full attention Q×K^T GEMM.** Closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
4. **Per-token FP8 activation scaling in FMHA.** Not validated. Out of scope.
5. **2:4 structured sparsity on FP4 expert weights.** V4 not trained with structured sparsity. Off the table for the released checkpoint.
6. **NVFP4 LM head + MTP head.** Big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against held-out eval before shipping.
---
## Key numbers to remember
| Config | n_h | top_k | s_k decode | n_kv_tiles | Multi-tile? |
|---|---:|---:|---:|---:|:---|
| Flash decode | 64 | 512 | 640 | 5 | YES |
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
| Current single-tile test | 1 | — | 128 | 1 | NO |
Production decode needs the multi-tile path. Today's Python KV merge ships correct results at the cost of 5–9 launches per step. Whether that cost matters is what Priority 1 measures.
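The `s_k decode` and `n_kv_tiles` columns follow from two numbers. This sketch assumes 128-row KV tiles and a 128-token SWA window inferred from the table (`s_k = top_k + 128`). Both values are read off the table, not the kernel source, and `decode_kv_shape` is a hypothetical helper.

```python
import math
from typing import Tuple

KV_TILE_ROWS = 128   # rows per KV tile, as in the single-tile test
SWA_WINDOW = 128     # inferred from the table: s_k decode - top_k


def decode_kv_shape(top_k: int, swa_window: int = SWA_WINDOW,
                    kv_tile_rows: int = KV_TILE_ROWS) -> Tuple[int, int, bool]:
    """Return (s_k, n_kv_tiles, needs_multi_tile) for one decode step."""
    s_k = top_k + swa_window
    n_kv_tiles = math.ceil(s_k / kv_tile_rows)
    return s_k, n_kv_tiles, n_kv_tiles > 1


print(decode_kv_shape(512))   # Flash decode: (640, 5, True)
print(decode_kv_shape(1024))  # Pro decode: (1152, 9, True)
```

With Python KV merge, `n_kv_tiles` is also the number of FMHA launches per call.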