Files
nvfp4-megamoe-kernel/archived_plans/NEXT_PRIORITIES.md
biondizzle 4b9eed02e1 Cleanup C1-C7: delete dead CuTeDSL FMHA, test probes, scratch files
- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*,
  test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
2026-05-30 21:08:12 +00:00

12 KiB
Raw Blame History

NEXT PRIORITIES — verified against code, not against status docs

Why this file exists: the agent's CURRENT_ISSUE.md tracks the thing it's currently building, not the state of the production path. Those have diverged. This doc resets the priorities against the actual code in dsv4/.

Method: every "current state" claim below was read directly off the source. Where the agent's notes disagree, the code wins. Per doctrine rule 3.


Verified state (read off the code, 2026-05-30)

Item Claimed Verified Evidence
Indexer FP4 dequant (P1) "fixed" DONE dsv4/kernels/indexer/indexer_score_topk.cu:41 has E2M1_LUT[8] = {0, 0.5, 1, 1.5, 2, 3, 4, 6}. Second copy in kernels/cuda/ also fixed (line 17). Live path via score_topk.py:29 builds the indexer copy.
Python KV merge killed (P2) "milestones 45 in progress" NOT DONE production.py:236296 still has the segment loop, torch.cuda.synchronize() at :279 inside the inner loop, eager exp/log merge at :286294. Nothing wired to any 6-warp variant.
6-warp raw-CUDA FMHA "multihead milestone 5 done" WORKS IN ISOLATION, T=1 decode, HD∈{16,64,128,256}, single KV tile, cos 0.999997+. Genuinely Blackwell-native: tcgen05.mma SS, UMMA descriptors, TMEM alloc + tcgen05.ld.sync.aligned.32x32b.x8.b32. Not called from production.
TMA descriptor blocker "INVALID_VALUE on B200" PARTIALLY DEMYSTIFIED docs/cuda13_tma_notes.md and memory/2026-05-29-tma-async.md: CUDA 13 needs byte strides, not element strides. Descriptors now create OK. But cp.async.bulk.tensor.{2d,3d} hangs — mbarrier never signals. Root cause "unknown" per agent.
Multi-row softmax T>32 "blocked on TMEM read" ⚠ Same fmha_6warp_multihead.cuh:196218 softmax is single-row (warp 0, row 0). Agent says fix is 16x256b.x1 instead of 32x32b.x8. That's a guess until the TMEM column layout is printed and confirmed.
FMHA file count n/a 🚨 7 .cuh variants fmha_6warp.cuh, _multihead, _multirow, _tma, _tma_multirow, _tma_multitile, _tma_multirow_multitile. Plus fmha_sm100_tc.cuh, fmha_epilogue_sm100.cuh, fmha_tma.cuh, fmha_umma_desc.cuh. None integrated to production. Going wide instead of deep.
FP4 attention epilogue "blocked" Correctly blocked fmha_common.cuh:32, fmha_epilogue_sm100.cuh:27, fmha.py:51 all note dependence on epilogue rewrite. Real dependency, not procrastination.

The pattern. P1 actually shipped. Everything else is exploration without integration. Seven .cuh variants were forked along three optimization axes (TMA, multirow, multitile), but the production path has not advanced one inch — same Python merge, same per-tile sync, same launch count per decoded token. This is the failure mode you flagged: the agent reports milestones inside a sandbox while the hot path it's supposed to fix is untouched.


Priority order (do these in this sequence — do not parallelize)

P3 — Wire fmha_6warp_multihead.cuh into production.py, decode-only DONE

Shipped in commits 1e6adf5..6421f7c (2026-05-30).

  • fmha_multihead_capi.cu: pure C API, compiled with nvcc -arch=sm_100a
  • fmha_multihead_op.py: ctypes loader + nvcc precompile + custom_op
  • production.py: fast path for T=1, n_segments=1, hd∈{64,128,256}
  • Grid: dim3(1, n_h, batch) — 1 CTA per (head, batch)
  • MQA/GQA: K/V repeat_interleave for shared heads
  • Fixed: double normalization bug in epilogue (P was already normalized)
  • Fixed: torch JIT compiles with -arch=sm_100 (not sm_100a) — use nvcc+ctypes
  • Test: test_p3_fast_decode.py — 12 raw + 5 API configs, cos >= 0.999990

P4 — Resolve the TMA hang RESOLVED

Root cause: GMEM pointer misalignment. TMA requires 128-byte aligned GMEM addresses. With proper alignment, ALL descriptor configs work (no swizzle, swizzle_128B, OOB_FILL_ZERO). The bit-21 workaround was NOT needed. See docs/p4_tma_hang_resolution.md.

This is the most-leverage move, not the most-exciting one. The kernel works in isolation for the exact shape that dominates decode (T=1, single KV tile, multi-head, HD∈{64,128,256}). Wire it. Stop forking variants.

Definition of done:

  1. production.py:_run_fmha_segmented has a fast path: if T == 1 and n_segments == 1, call the 6-warp kernel via a torch custom op. Single launch. No torch.cuda.synchronize() on the hot path. No per-tile allocations.
  2. Numerical parity gate: cos ≥ 0.999998 against the current Python-merge path on the cases it handles. Not "looks fine" — bitwise diffable test.
  3. Launch count measurement before/after on one decoded V4-Pro CSA layer. Record with Nsight Systems or cudaLaunchKernel counter. Target on the new path: 1 kernel launch, 0 cudaDeviceSynchronize on the hot path. This is the number that proves P2 progress; cosine alone does not.
  4. The slow path (multi-segment, prefill T>1) stays as-is for now. Don't touch it until the decode path is integrated and measured.

Failure modes to watch for (call them out if you see them):

  • Agent creates an 8th .cuh variant. The answer is "integrate, don't fork."
  • Agent regresses the cosine to "make integration easier." Parity is the gate.
  • Agent removes the Python fallback before the fast path covers all shapes used.

P4 — Resolve the TMA hang with print-and-diff, NOT another guess

Memory note ends with "root cause unknown" and three speculative options. That's the prompt to apply doctrine rule 3, not to pick one.

Definition of done:

  1. Dump the descriptor bytes from a working CuTeDSL TMA path (the existing CuTeDSL FMHA already runs TMA correctly — that's the oracle).
  2. Dump the descriptor bytes from the raw-CUDA cuTensorMapEncodeTiled path that hangs.
  3. memcmp them. Document the byte-level differences in a .md paper trail alongside cuda13_tma_notes.md. The hang is almost certainly one of: swizzle mode, fill mode, interleave layout, or oobFill — the 5 enum fields the API takes after the strides. Print every field of both descriptors side by side.
  4. Once they match: re-run, confirm no hang. If they don't match: code to the diff, not to a new theory.

Failure modes to watch for:

  • Agent reaches for "manually construct TMA descriptor bytes" (option B from the memory note) without doing the diff first. That's a bigger guess, not a smaller one.
  • Agent declares this "deferred" and moves on. TMA is on the critical path for prefill / long-context throughput; "decode works without TMA" is true but doesn't generalize.

P5 — Integrate multi-tile FMHA into production DONE

Shipped 2026-05-30. Wired the existing D1.5 kernel (fmha_6warp_tma_multirow_multitile.cuh) via fmha_multitile_capi.cu + fmha_multitile_op.py into production.py. 18 integration tests pass.

After P3, the decode fast path bypasses the Python merge for n_segments==1. But CSA with top_k=1024 always has n_segments=8, so the moment top_k > 128 we're back on the slow Python path. This priority extends the 6-warp kernel to loop KV tiles internally with FlashAttention-2 running max/sum.

Definition of done:

  1. The 6-warp kernel (the same file, not a new variant) accepts a kv_tile_count parameter and loops over it internally, maintaining (row_max, row_sum) rescale across tiles. Standard FA2 shape.
  2. Single launch handles any n_segments. The Python merge in production.py for multi-tile is deleted, not "kept as fallback."
  3. Parity gate: cos ≥ 0.999998 vs the now-deleted Python merge, captured in a regression test before deletion.
  4. Launch count on V4-Pro CSA layer: still 1, regardless of top_k. The cudaDeviceSynchronize calls go to zero, period.

Failure modes to watch for:

  • Agent writes an fmha_6warp_multikvtile.cuh instead of extending the chosen file.
  • Agent moves the rescale to host-side "for clarity." The rescale must live in the kernel or the launch-count guarantee breaks.

P6 — One-way TMEM→regs→SMEM→GMEM epilogue, with FP4 hook DONE

Shipped 2026-05-30.

  • fmha_6warp_multihead.cuh: Rewritten epilogue with proper Blackwell pipeline:
    1. TMEM → registers (tcgen05.ld, warp-collective)
    2. epilogue_op in registers (normalize, ENABLE_FP4_EPILOGUE template param)
    3. Registers → SMEM (row-major sO_epi)
    4. SMEM → GMEM (direct write)
  • fmha_6warp_tma_multirow_multitile.cuh: Same epilogue pattern for multi-tile.
  • cp.async.bulk.tensor store (SMEM→GMEM) is NOT available on SM100. CUTLASS SM100 epilogue uses st.global directly.
  • FP4 pack hook: ENABLE_FP4_EPILOGUE template param (off by default).
  • Test: test_p6_tma_epilogue.py — 9 configs ALL PASS, cos >= 0.999990

P7 — Multi-row softmax T>32, by printing the TMEM column layout DONE

Shipped 2026-05-30.

  • docs/p7_tmem_column_layout.md: Verified that tcgen05.ld 32x32b.x8 is correct. Each call reads 8 KV positions for 32 rows. No instruction change needed.
  • The multi-tile kernel already handles T=1..128 with 4 softmax warps.
  • Test: test_p7_multi_row_softmax.py — 10 configs ALL PASS, cos >= 0.999996

P8 — Consolidate: delete 6 of the 7 6-warp variants DONE

Shipped 2026-05-30.

  • Kept: fmha_6warp_tma_multirow_multitile.cuh (THE production kernel)
  • Deleted: fmha_6warp.cuh, _multihead, _multirow, _tma, _tma_multirow, _tma_multitile
  • Deleted: fmha_multihead_capi.cu, fmha_multihead_op.py
  • production.py: Unified dispatch to _dsv4_attention_multitile for all fast-path cases
  • ls dsv4/kernels/attention/fmha_6warp*.cuh returns ONE file

What is not on this list, and why

  • MoE / GEMM stack: already production-grade per audit. Don't touch.
  • mHC implementation: correct per paper, working.
  • CSA compressor (overlapped 2m): correct per paper, working.
  • Router (sqrt(softplus) + aux-loss-free): correct per paper, working.
  • MegaMoE EP overlap (dispatch/combine waves): correctly scoped out — that's a multi-node EP concern, not a single-GPU kernel concern.
  • Reasoning effort modes, Quick Instruction, OPD: post-training surface, not kernel surface.

DOCTRINE — every priority above is gated on these, no exceptions

  1. CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python. The Python KV merge in production.py is the cautionary tale. The 6-warp raw-CUDA kernel is the correct fallback for that wall — but it has to be integrated, not exhibited.

  2. Raw CUDA ≠ scalar math. Every priority above keeps tcgen05 / UMMA / TMEM / TMA / warp-level reductions. No scalar dot products as "temporary simplification." That's how we got the indexer LUT bug — a "temporary" scalar oracle that was wrong and trusted as a reference.

  3. Print, don't guess. Code to the data. Two specific applications this round:

    • P4 (TMA hang): dump descriptor bytes, diff, code to the diff. Do not guess at "format mismatch."
    • P7 (TMEM layout): print the observed (warp, lane) → (row, col) map, pick the instruction from the map. Do not pick the instruction first.
  4. Integration over exploration. Seven .cuh variants is the diagnostic. The priority order above is deliberately "wire the one that works" before "make it handle more shapes." A working kernel not on the production path is worth zero. Resist the urge to fork a new file for each new concern; extend the chosen file or step back to plan.

  5. Falsifiable gates only. Every "definition of done" above has a number (cosine, launch count, file count) or a binary check. "Looks fine," "milestone complete," and "in progress" are not gates. If a status doc says "DONE" without a number next to it, the doctrine reads it as "NOT DONE."