nvfp4-megamoe-kernel/NEXT_PRIORITIES.md
biondizzle 2649488d13 P5: in-kernel multi-KV-tile FA2 online softmax in fmha_6warp_multihead.cuh
- Kernel loops over KV tiles internally with running max/sum rescale
- SMEM accumulator sOacc[hd] replaces TMEM accumulation across tiles
- P is UN-NORMALIZED for multi-tile (exp(s-max), not /sum)
- Per KV tile: QK→softmax→PV→TMEM→read→add to sOacc
- Final: O = sOacc / running_sum
- Single tile (n_kv_tiles=1): same as before, no rescale
- Updated CAPI, Python loader, production.py fast path
- Added multi-tile test cases (N=256, 512)
2026-05-30 08:46:09 +00:00


NEXT PRIORITIES — verified against code, not against status docs

Why this file exists: the agent's CURRENT_ISSUE.md tracks the thing it's currently building, not the state of the production path. Those have diverged. This doc resets the priorities against the actual code in dsv4/.

Method: every "current state" claim below was read directly off the source. Where the agent's notes disagree, the code wins, per doctrine rule 3.


Verified state (read off the code, 2026-05-30)

| Item | Claimed | Verified | Evidence |
|---|---|---|---|
| Indexer FP4 dequant (P1) | "fixed" | DONE | dsv4/kernels/indexer/indexer_score_topk.cu:41 has E2M1_LUT[8] = {0, 0.5, 1, 1.5, 2, 3, 4, 6}. Second copy in kernels/cuda/ also fixed (line 17). Live path via score_topk.py:29 builds the indexer copy. |
| Python KV merge killed (P2) | "milestones 4–5 in progress" | NOT DONE | production.py:236–296 still has the segment loop, torch.cuda.synchronize() at :279 inside the inner loop, and the eager exp/log merge at :286–294. Nothing is wired to any 6-warp variant. |
| 6-warp raw-CUDA FMHA | "multihead milestone 5 done" | WORKS IN ISOLATION: T=1 decode, HD∈{16,64,128,256}, single KV tile, cos 0.999997+ | Genuinely Blackwell-native: tcgen05.mma SS, UMMA descriptors, TMEM alloc + tcgen05.ld.sync.aligned.32x32b.x8.b32. Not called from production. |
| TMA descriptor blocker | "INVALID_VALUE on B200" | PARTIALLY DEMYSTIFIED | docs/cuda13_tma_notes.md and memory/2026-05-29-tma-async.md: CUDA 13 needs byte strides, not element strides. Descriptors now create OK, but cp.async.bulk.tensor.{2d,3d} hangs: the mbarrier never signals. Root cause "unknown" per agent. |
| Multi-row softmax T>32 | "blocked on TMEM read" | ⚠ Same | fmha_6warp_multihead.cuh:196–218: softmax is single-row (warp 0, row 0). Agent says the fix is 16x256b.x1 instead of 32x32b.x8. That's a guess until the TMEM column layout is printed and confirmed. |
| FMHA file count | n/a | 🚨 7 .cuh variants | fmha_6warp.cuh, _multihead, _multirow, _tma, _tma_multirow, _tma_multitile, _tma_multirow_multitile. Plus fmha_sm100_tc.cuh, fmha_epilogue_sm100.cuh, fmha_tma.cuh, fmha_umma_desc.cuh. None integrated to production. Going wide instead of deep. |
| FP4 attention epilogue | "blocked" | Correctly blocked | fmha_common.cuh:32, fmha_epilogue_sm100.cuh:27 and fmha.py:51 all note the dependence on the epilogue rewrite. Real dependency, not procrastination. |
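The P1 fix is a table fix, so it can be cross-checked off-device against the table values themselves (not trusted as an oracle on its own; see doctrine rule 2). A minimal pure-Python decoder, assuming the usual E2M1 encoding (bit 3 = sign, bits 0-2 index the magnitude table) and a low-nibble-first byte packing; the packing order and the single `scale` argument are assumptions to check against indexer_score_topk.cu, not a transcription of it:

```python
# Magnitudes indexed by the 3 low bits of an E2M1 code (same values as E2M1_LUT[8]).
E2M1_LUT = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble):
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the LUT."""
    if not 0 <= nibble <= 0xF:
        raise ValueError(f"not a 4-bit code: {nibble}")
    magnitude = E2M1_LUT[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_fp4_bytes(packed, scale=1.0):
    """Unpack two codes per byte (low nibble first: assumed) and apply one block scale."""
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0x0F) * scale)
        out.append(decode_e2m1(byte >> 4) * scale)
    return out
```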

The pattern. P1 actually shipped. Everything else is exploration without integration. Seven .cuh variants were forked along three optimization axes (TMA, multirow, multitile), but the production path has not advanced one inch — same Python merge, same per-tile sync, same launch count per decoded token. This is the failure mode you flagged: the agent reports milestones inside a sandbox while the hot path it's supposed to fix is untouched.


Priority order (do these in this sequence — do not parallelize)

P3 — Wire fmha_6warp_multihead.cuh into production.py, decode-only DONE

Shipped in commits 1e6adf5..6421f7c (2026-05-30).

  • fmha_multihead_capi.cu: pure C API, compiled with nvcc -arch=sm_100a
  • fmha_multihead_op.py: ctypes loader + nvcc precompile + custom_op
  • production.py: fast path for T=1, n_segments=1, hd∈{64,128,256}
  • Grid: dim3(1, n_h, batch) — 1 CTA per (head, batch)
  • MQA/GQA: K/V repeat_interleave for shared heads
  • Fixed: double normalization bug in epilogue (P was already normalized)
  • Fixed: torch JIT builds with -arch=sm_100 (not sm_100a), so the kernel is built with nvcc and loaded via ctypes
  • Test: test_p3_fast_decode.py — 12 raw + 5 API configs, cos >= 0.999990
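The P3 gates are numeric, so a script can check them instead of a reader skimming logs. A minimal pure-Python sketch, assuming the launch and sync counts have already been extracted (for example from an Nsight Systems trace); `cosine` and `decode_gate` are hypothetical helpers, not functions from test_p3_fast_decode.py, and `cos_min` defaults to the 0.999998 parity gate from the P3 definition of done:

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length flat sequences (Python floats are float64)."""
    if len(a) != len(b) or not a:
        raise ValueError("need two non-empty sequences of equal length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine is undefined for an all-zero vector")
    return dot / (norm_a * norm_b)


def decode_gate(ref, out, launches, syncs, cos_min=0.999998, expected_launches=1):
    """Hypothetical P3 gate: parity vs. the Python-merge path, one launch, zero syncs."""
    cos = cosine(ref, out)
    checks = {
        "parity": cos >= cos_min,
        "launch_count": launches == expected_launches,
        "no_sync": syncs == 0,
    }
    report = f"cos={cos:.7f} launches={launches} syncs={syncs} " + " ".join(
        f"{name}={'PASS' if ok else 'FAIL'}" for name, ok in checks.items()
    )
    return all(checks.values()), report
```

Logging `report` rather than just the boolean keeps the numbers in the record, which is what doctrine rule 5 asks for.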

This is the highest-leverage move, not the most exciting one. The kernel works in isolation for the exact shape that dominates decode (T=1, single KV tile, multi-head, HD∈{64,128,256}). Wire it. Stop forking variants.

Definition of done:

  1. production.py:_run_fmha_segmented has a fast path: if T == 1 and n_segments == 1, call the 6-warp kernel via a torch custom op. Single launch. No torch.cuda.synchronize() on the hot path. No per-tile allocations.
  2. Numerical parity gate: cos ≥ 0.999998 against the current Python-merge path on the cases it handles. Not "looks fine": a bitwise-diffable test.
  3. Launch count measurement before/after on one decoded V4-Pro CSA layer. Record with Nsight Systems or a cudaLaunchKernel counter. Target on the new path: 1 kernel launch, 0 cudaDeviceSynchronize on the hot path. This is the number that proves P2 progress; cosine alone does not.
  4. The slow path (multi-segment, prefill T>1) stays as-is for now. Don't touch it until the decode path is integrated and measured.

Failure modes to watch for (call them out if you see them):

  • Agent creates an 8th .cuh variant. The answer is "integrate, don't fork."
  • Agent regresses the cosine to "make integration easier." Parity is the gate.
  • Agent removes the Python fallback before the fast path covers all shapes used.

P4 — Resolve the TMA hang with print-and-diff, NOT another guess RESOLVED

Root cause: GMEM pointer misalignment. TMA requires 128-byte-aligned GMEM addresses. With proper alignment, ALL descriptor configs work (no swizzle, swizzle_128B, OOB_FILL_ZERO). The bit-21 workaround was NOT needed. See docs/p4_tma_hang_resolution.md.

Memory note ends with "root cause unknown" and three speculative options. That's the prompt to apply doctrine rule 3, not to pick one.

Definition of done:

  1. Dump the descriptor bytes from a working CuTeDSL TMA path (the existing CuTeDSL FMHA already runs TMA correctly — that's the oracle).
  2. Dump the descriptor bytes from the raw-CUDA cuTensorMapEncodeTiled path that hangs.
  3. memcmp them. Document the byte-level differences in a .md paper trail alongside cuda13_tma_notes.md. The hang is almost certainly in one of the four enum arguments cuTensorMapEncodeTiled takes after the strides: interleave, swizzle, l2Promotion, or oobFill. Print every field of both descriptors side by side.
  4. Once they match: re-run, confirm no hang. If they don't match: code to the diff, not to a new theory.
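Steps 1-3 need the two raw dumps (for example, each 128-byte CUtensorMap copied to host as bytes); the diff itself is plain Python. A minimal sketch: CUtensorMap is an opaque 128-byte struct, so this compares it as sixteen 64-bit words without trying to decode fields, and it assumes both dumps were taken on a little-endian host:

```python
import struct

TENSOR_MAP_BYTES = 128  # sizeof(CUtensorMap): opaque, 16 x 64-bit words


def _words(blob):
    if len(blob) != TENSOR_MAP_BYTES:
        raise ValueError(f"expected a {TENSOR_MAP_BYTES}-byte descriptor dump, got {len(blob)}")
    return struct.unpack("<16Q", blob)  # little-endian host assumed


def diff_descriptors(good, bad):
    """Return (word_index, byte_offset, good_hex, bad_hex) for every 64-bit word that differs."""
    return [
        (i, i * 8, f"{g:016x}", f"{b:016x}")
        for i, (g, b) in enumerate(zip(_words(good), _words(bad)))
        if g != b
    ]


def side_by_side(good, bad):
    """Format both dumps word by word, flagging mismatches, for the .md paper trail."""
    lines = []
    for i, (g, b) in enumerate(zip(_words(good), _words(bad))):
        flag = "  <-- differs" if g != b else ""
        lines.append(f"+{i * 8:3d}  {g:016x}  {b:016x}{flag}")
    return lines
```

The byte diff complements, rather than replaces, printing the arguments passed to each encode call: the arguments say what was requested, the bytes say what the driver actually encoded.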

Failure modes to watch for:

  • Agent reaches for "manually construct TMA descriptor bytes" (option B from the memory note) without doing the diff first. That's a bigger guess, not a smaller one.
  • Agent declares this "deferred" and moves on. TMA is on the critical path for prefill / long-context throughput; "decode works without TMA" is true but doesn't generalize.
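The P4 resolution traced the hang to GMEM misalignment, and the CUDA 13 note is about byte versus element strides; both are cheap to check on the host before a descriptor is encoded. A minimal sketch for a densely packed, row-major tensor. The 128-byte address default follows the resolution note; the driver documentation's own minimum for globalAddress is smaller (16 bytes without interleave), so treat 128 as this project's empirical requirement. The 16-byte multiple for globalStrides is the documented constraint. The function name is hypothetical:

```python
def check_tma_global_layout(base_addr, dims, elem_bytes, addr_align=128, stride_align=16):
    """Pre-flight checks for a dense, row-major global tensor before cuTensorMapEncodeTiled.

    dims is innermost-first (dims[0] is the contiguous dimension). Returns
    (byte_strides, problems); byte_strides has rank-1 entries, one per outer dim,
    in the units globalStrides expects (bytes).
    """
    problems = []
    if base_addr % addr_align:
        problems.append(f"base address 0x{base_addr:x} is not {addr_align}-byte aligned")
    byte_strides = []
    stride = elem_bytes
    for extent in dims[:-1]:
        stride *= extent  # no padding: assumes a densely packed tensor
        byte_strides.append(stride)
    for i, s in enumerate(byte_strides):
        if s % stride_align:
            problems.append(f"globalStrides[{i}] = {s} bytes is not a multiple of {stride_align}")
    return byte_strides, problems
```

For a contiguous torch tensor `t`, the inputs would be `t.data_ptr()`, `list(reversed(t.shape))` and `t.element_size()`; a non-empty `problems` list should fail loudly instead of reaching the encode call.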

P5 — In-kernel online softmax across KV tiles (the real P2)

After P3, the decode fast path bypasses the Python merge for n_segments==1. But CSA with top_k=1024 always has n_segments=8, so the moment top_k > 128 we're back on the slow Python path. This priority extends the 6-warp kernel to loop KV tiles internally with FlashAttention-2 running max/sum.
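For reference, the merge that P5 moves into the kernel is the standard log-sum-exp combination of per-segment results. A pure-Python sketch of that math for a single query row, without the 1/sqrt(hd) score scale (both simplifications); it illustrates the technique and is not a copy of production.py:

```python
import math


def attend_segment(q, keys, values):
    """Softmax attention of one query over one KV segment.

    Returns (output normalized within the segment, log-sum-exp of the segment's scores).
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge_segments(partials):
    """Combine (output, lse) pairs from disjoint KV segments into the full-softmax output."""
    lse_max = max(lse for _, lse in partials)
    total_lse = lse_max + math.log(sum(math.exp(lse - lse_max) for _, lse in partials))
    hd = len(partials[0][0])
    return [sum(math.exp(lse - total_lse) * out[d] for out, lse in partials) for d in range(hd)]
```

Every segment's normalized output and LSE has to be materialized and re-weighted after the fact; the FA2 loop keeps un-normalized partials and a running max inside the kernel instead, which is what lets the launch count stay at 1.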

Definition of done:

  1. The 6-warp kernel (the same file, not a new variant) accepts a kv_tile_count parameter and loops over it internally, maintaining (row_max, row_sum) rescale across tiles. Standard FA2 shape.
  2. Single launch handles any n_segments. The Python merge in production.py for multi-tile is deleted, not "kept as fallback."
  3. Parity gate: cos ≥ 0.999998 vs the now-deleted Python merge, captured in a regression test before deletion.
  4. Launch count on V4-Pro CSA layer: still 1, regardless of top_k. The cudaDeviceSynchronize calls go to zero, period.
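The recurrence in item 1 (and in the P5 commit notes: un-normalized P per tile, one divide at the end) can be checked off-device against a direct softmax. A pure-Python sketch for a single query row, without the 1/sqrt(hd) scale; the tile loop stands in for the kernel's KV-tile loop and `acc` for sOacc. It checks the math; it is not a replacement for the parity test against the Python merge:

```python
import math


def online_softmax_attention(q, keys, values, tile=128):
    """FA2-style single pass over KV tiles with a running max and running sum.

    P is left un-normalized per tile (exp(s - running_max)); the accumulator is
    rescaled whenever the max grows, and divided by the running sum once at the end.
    """
    hd = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * hd  # plays the role of sOacc
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]  # un-normalized probabilities
        running_sum = running_sum * correction + sum(p)
        acc = [a * correction + sum(pj * v[d] for pj, v in zip(p, v_tile)) for d, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

The step that is easy to drop is rescaling `acc` by exp(old_max - new_max) before adding the new tile's P·V; without it, multi-tile output drifts as soon as a later tile raises the max, while the single-tile case still passes.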

Failure modes to watch for:

  • Agent writes an fmha_6warp_multikvtile.cuh instead of extending the chosen file.
  • Agent moves the rescale to host-side "for clarity." The rescale must live in the kernel or the launch-count guarantee breaks.

P6 — One-way TMEM→regs→SMEM→GMEM epilogue, with FP4 hook

Unlocks NVFP4-1.2 (FP4 output fusion) and multi-CTA grids. The pattern already runs correctly in dsv4/kernels/gemm/dense.py; the symbols are already imported into fmha.py (lines 71–72) but unused. The work is wiring, not invention.

Definition of done:

  1. The 6-warp kernel's epilogue uses epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition + TMA store, with an epilogue_op lambda slot (constexpr) for fusion. Same shape as MoE GEMM.
  2. epilogue_op = lambda x: x ships as default. FP4 pack lambda lands as a flag, off by default, behind its own test.
  3. Multi-CTA grid smoke test: M ≥ 256 prefill, flat_divide coords accepted, no crash, correct numerics.
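The hook design in items 1-2 can be sketched in Python, with a default argument standing in for a constexpr template parameter. `identity` mirrors the default `lambda x: x`; `fp4_encode`, `pack_nibbles` and the low-nibble-first order are hypothetical stand-ins for the FP4 pack lambda, and inputs are assumed to be already divided by their block scale:

```python
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def identity(x):
    """Default epilogue_op: store the accumulator unchanged."""
    return x


def fp4_encode(x):
    """Round an already-scaled float to the nearest E2M1 code, ties to even, saturating at +/-6."""
    mag = min(abs(x), 6.0)
    # Key on (distance, mantissa bit) so a tie between neighbours picks the even code.
    idx = min(range(8), key=lambda i: (abs(E2M1_VALUES[i] - mag), i & 1))
    return idx | (0x8 if x < 0 else 0)


def epilogue(acc_tile, epilogue_op=identity):
    """Apply the fusion hook to each accumulator element on the way to the store."""
    return [epilogue_op(x) for x in acc_tile]


def pack_nibbles(codes):
    """Pack 4-bit codes two per byte, low nibble first (assumed order)."""
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))
```

Keeping the pack as a separate op means the epilogue rewrite can land and pass its own test with `identity` before any FP4 code is switched on.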

Failure modes to watch for:

  • Agent re-introduces epilogue_tma_store "because it's simpler." That's the blocker we're removing — don't re-add it.
  • Agent enables FP4 pack at the same time as the epilogue rewrite. Two changes, one test. Land the epilogue first, FP4 fusion second.

P7 — Multi-row softmax T>32, by printing the TMEM column layout

The agent's plan ("use 16x256b.x1") is a guess. May be right; may not be. Before changing the instruction:

Definition of done:

  1. Print the TMEM column map for HD=256, T=128 case: for each (warp, lane, tmem column), which (row, col) of S does it own? Write the observed map into a .md doc.
  2. Pick the TMEM load instruction that matches the observed map. If it's 16x256b.x1, fine — but with the table backing the choice.
  3. Parity gate: cos ≥ 0.999998 for T∈{1, 32, 64, 128} all in the same kernel.
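Step 1 is a dump plus a consistency check. A sketch of the host-side half, assuming the kernel has been instrumented to emit one (warp, lane, tmem_col, s_row, s_col) record per value it reads (that record format is hypothetical): it confirms every element of S is owned exactly once and renders the observed map as a table for the .md doc:

```python
from collections import Counter


def check_ownership_map(records, rows, cols):
    """records: iterable of (warp, lane, tmem_col, s_row, s_col) from a device dump.

    Returns (missing, duplicated): S elements nobody owns, and S elements owned more than once.
    """
    counts = Counter((r, c) for _, _, _, r, c in records)
    missing = [(r, c) for r in range(rows) for c in range(cols) if (r, c) not in counts]
    duplicated = sorted(rc for rc, n in counts.items() if n > 1)
    return missing, duplicated


def to_markdown(records):
    """Render the observed (warp, lane, tmem col) -> (row, col) map as a Markdown table."""
    lines = ["| warp | lane | tmem col | S row | S col |", "|---|---|---|---|---|"]
    for warp, lane, tmem_col, s_row, s_col in sorted(records):
        lines.append(f"| {warp} | {lane} | {tmem_col} | {s_row} | {s_col} |")
    return "\n".join(lines)
```

Only once both lists come back empty does the table describe a complete layout worth choosing an instruction from.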

Failure modes to watch for:

  • Agent picks the instruction first, then "interprets the layout to match." Layout first, instruction second.

P8 — Consolidate: delete 6 of the 7 6-warp variants

After P3–P7, exactly one variant should exist. The other six are landmines for the next agent (and for you when you context-switch back in three weeks).

Definition of done: ls dsv4/kernels/attention/fmha_6warp*.cuh returns one file. Tests updated to point at it. git rm for the rest. No "archive/" folder.
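The P8 gate is binary, so it can run in CI rather than be eyeballed. A sketch; `consolidation_gate` is a hypothetical helper, and its archive check only looks for an archive/ directory inside the attention folder, a narrower reading of "no archive/ folder" than a repo-wide search:

```python
from pathlib import Path


def consolidation_gate(attention_dir="dsv4/kernels/attention"):
    """P8 gate: exactly one fmha_6warp*.cuh left and no archive/ folder beside it."""
    root = Path(attention_dir)
    variants = sorted(p.name for p in root.glob("fmha_6warp*.cuh"))
    has_archive = (root / "archive").exists()
    return len(variants) == 1 and not has_archive, variants
```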


What is not on this list, and why

  • MoE / GEMM stack: already production-grade per audit. Don't touch.
  • mHC implementation: correct per paper, working.
  • CSA compressor (overlapped 2m): correct per paper, working.
  • Router (sqrt(softplus) + aux-loss-free): correct per paper, working.
  • MegaMoE EP overlap (dispatch/combine waves): correctly scoped out — that's a multi-node EP concern, not a single-GPU kernel concern.
  • Reasoning effort modes, Quick Instruction, OPD: post-training surface, not kernel surface.

DOCTRINE — every priority above is gated on these, no exceptions

  1. CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python. The Python KV merge in production.py is the cautionary tale. The 6-warp raw-CUDA kernel is the correct fallback for that wall — but it has to be integrated, not exhibited.

  2. Raw CUDA ≠ scalar math. Every priority above keeps tcgen05 / UMMA / TMEM / TMA / warp-level reductions. No scalar dot products as "temporary simplification." That's how we got the indexer LUT bug — a "temporary" scalar oracle that was wrong and trusted as a reference.

  3. Print, don't guess. Code to the data. Two specific applications this round:

    • P4 (TMA hang): dump descriptor bytes, diff, code to the diff. Do not guess at "format mismatch."
    • P7 (TMEM layout): print the observed (warp, lane) → (row, col) map, pick the instruction from the map. Do not pick the instruction first.
  4. Integration over exploration. Seven .cuh variants is the diagnostic. The priority order above is deliberately "wire the one that works" before "make it handle more shapes." A working kernel not on the production path is worth zero. Resist the urge to fork a new file for each new concern; extend the chosen file or step back to plan.

  5. Falsifiable gates only. Every "definition of done" above has a number (cosine, launch count, file count) or a binary check. "Looks fine," "milestone complete," and "in progress" are not gates. If a status doc says "DONE" without a number next to it, the doctrine reads it as "NOT DONE."