- Kernel loops over KV tiles internally with running max/sum rescale
- SMEM accumulator sOacc[hd] replaces TMEM accumulation across tiles
- P is UN-NORMALIZED for multi-tile (exp(s-max), not /sum)
- Per KV tile: QK→softmax→PV→TMEM→read→add to sOacc
- Final: O = sOacc / running_sum
- Single tile (n_kv_tiles=1): same as before, no rescale
- Updated CAPI, Python loader, production.py fast path
- Added multi-tile test cases (N=256, 512)
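This is a pure-Python sketch of the loop these notes describe, for one query row. It keeps:

- a running max and running sum per row,
- an un-normalized P per tile,
- an accumulator that is rescaled whenever the max moves,
- a single divide at the very end.

The function names are hypothetical, and plain lists stand in for TMEM and the sOacc SMEM buffer. It is a reference model for checking numerics, not the kernel.

```python
import math

def attention_single_pass(scores, values):
    """Reference: softmax over all keys at once, then P @ V, for one query row."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    denom = sum(p)
    hd = len(values[0])
    return [sum(p[j] * values[j][d] for j in range(len(scores))) / denom for d in range(hd)]

def attention_multi_tile(scores, values, tile):
    """Loop over KV tiles with a running max/sum; P stays un-normalized per tile."""
    hd = len(values[0])
    row_max = -math.inf
    row_sum = 0.0
    acc = [0.0] * hd  # plays the role of the SMEM accumulator sOacc[hd]
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(row_max, max(s_tile))
        scale = math.exp(row_max - new_max)  # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - new_max) for s in s_tile]  # exp(s - max), NOT divided by the sum
        row_sum = row_sum * scale + sum(p)
        for d in range(hd):
            acc[d] = acc[d] * scale + sum(p[j] * v_tile[j][d] for j in range(len(p)))
        row_max = new_max
    return [a / row_sum for a in acc]  # the only normalization: O = sOacc / running_sum
```

On the first tile the running max is -inf, so the rescale factor is exp(-inf) = 0 and the update reduces to the single-tile case. Dividing P by its tile sum and then also dividing by running_sum would normalize twice.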
NEXT PRIORITIES — verified against code, not against status docs
Why this file exists: the agent's CURRENT_ISSUE.md tracks the thing it's
currently building, not the state of the production path. Those have diverged.
This doc resets the priorities against the actual code in dsv4/.
Method: every "current state" claim below was read directly off the source. Where the agent's notes disagree, the code wins. Per doctrine rule 3.
Verified state (read off the code, 2026-05-30)
| Item | Claimed | Verified | Evidence |
|---|---|---|---|
| Indexer FP4 dequant (P1) | "fixed" | ✅ DONE | dsv4/kernels/indexer/indexer_score_topk.cu:41 has E2M1_LUT[8] = {0, 0.5, 1, 1.5, 2, 3, 4, 6}. Second copy in kernels/cuda/ also fixed (line 17). Live path via score_topk.py:29 builds the indexer copy. |
| Python KV merge killed (P2) | "milestones 4–5 in progress" | ❌ NOT DONE | production.py:236–296 still has the segment loop, torch.cuda.synchronize() at :279 inside the inner loop, eager exp/log merge at :286–294. Nothing wired to any 6-warp variant. |
| 6-warp raw-CUDA FMHA | "multihead milestone 5 done" | ⚠ WORKS IN ISOLATION | T=1 decode, HD∈{16,64,128,256}, single KV tile, cos 0.999997+. Genuinely Blackwell-native: tcgen05.mma SS, UMMA descriptors, TMEM alloc + tcgen05.ld.sync.aligned.32x32b.x8.b32. Not called from production. |
| TMA descriptor blocker | "INVALID_VALUE on B200" | ⚠ PARTIALLY DEMYSTIFIED | docs/cuda13_tma_notes.md and memory/2026-05-29-tma-async.md: CUDA 13 needs byte strides, not element strides. Descriptors now create OK. But cp.async.bulk.tensor.{2d,3d} hangs — mbarrier never signals. Root cause "unknown" per agent. |
| Multi-row softmax T>32 | "blocked on TMEM read" | ⚠ Same | fmha_6warp_multihead.cuh:196–218 softmax is single-row (warp 0, row 0). Agent says fix is 16x256b.x1 instead of 32x32b.x8. That's a guess until the TMEM column layout is printed and confirmed. |
| FMHA file count | n/a | 🚨 7 .cuh variants | fmha_6warp.cuh, _multihead, _multirow, _tma, _tma_multirow, _tma_multitile, _tma_multirow_multitile. Plus fmha_sm100_tc.cuh, fmha_epilogue_sm100.cuh, fmha_tma.cuh, fmha_umma_desc.cuh. None integrated to production. Going wide instead of deep. |
| FP4 attention epilogue | "blocked" | Correctly blocked | fmha_common.cuh:32, fmha_epilogue_sm100.cuh:27, fmha.py:51 all note dependence on epilogue rewrite. Real dependency, not procrastination. |
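The indexer LUT row can be checked from first principles rather than by trusting another copy of the table. The check decodes every 3-bit magnitude from its bit fields.

This sketch assumes the standard E2M1 layout:

- 1 sign bit, in bit 3 of the nibble,
- 2 exponent bits with bias 1,
- 1 mantissa bit.

The function names are illustrative. Nibble order within a packed byte is deliberately not covered.

```python
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # values quoted for indexer_score_topk.cu:41

def e2m1_magnitude(code3):
    """Decode a 3-bit E2M1 magnitude: 2 exponent bits (bias 1), 1 mantissa bit."""
    exp = (code3 >> 1) & 0b11
    man = code3 & 0b1
    if exp == 0:
        return 0.5 * man  # subnormal: 0 or 0.5
    return 2.0 ** (exp - 1) * (1.0 + 0.5 * man)

def e2m1_decode(nibble):
    """Decode a 4-bit FP4 (E2M1) code: bit 3 is the sign, bits 0-2 index the LUT."""
    sign = -1.0 if nibble & 0b1000 else 1.0
    return sign * E2M1_LUT[nibble & 0b111]
```

If the bit-derived magnitudes and the hard-coded LUT ever disagree, the LUT is wrong. This kind of independent oracle would have caught the original indexer bug.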
The pattern. P1 actually shipped. Everything else is exploration without integration. Seven .cuh variants were forked along three optimization axes (TMA, multirow, multitile), but the production path has not advanced one inch — same Python merge, same per-tile sync, same launch count per decoded token. This is the failure mode you flagged: the agent reports milestones inside a sandbox while the hot path it's supposed to fix is untouched.
Priority order (do these in this sequence — do not parallelize)
P3 — Wire fmha_6warp_multihead.cuh into production.py, decode-only ✅ DONE
Shipped in commits 1e6adf5..6421f7c (2026-05-30).
- fmha_multihead_capi.cu: pure C API, compiled with nvcc -arch=sm_100a
- fmha_multihead_op.py: ctypes loader + nvcc precompile + custom_op
- production.py: fast path for T=1, n_segments=1, hd∈{64,128,256}
- Grid: dim3(1, n_h, batch) — 1 CTA per (head, batch)
- MQA/GQA: K/V repeat_interleave for shared heads
- Fixed: double normalization bug in epilogue (P was already normalized)
- Fixed: torch's JIT extension build compiles with -arch=sm_100 (not sm_100a), so the kernel is precompiled with nvcc and loaded via ctypes instead
- Test: test_p3_fast_decode.py — 12 raw + 5 API configs, cos >= 0.999990
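The MQA/GQA bullet relies on repeat_interleave placing each KV head's copies next to each other. Query head h therefore reads KV head h // (n_h // n_kv).

Below is a small sketch of that mapping and of the dim3(1, n_h, batch) decode grid. The helper names are hypothetical, and the sketch models indices only, not the kernel.

```python
def kv_head_for(q_head, n_h, n_kv):
    """KV head a query head reads after K/V.repeat_interleave(n_h // n_kv) along the head dim."""
    if n_h % n_kv:
        raise ValueError("GQA needs n_h to be a multiple of n_kv")
    group = n_h // n_kv
    return q_head // group

def decode_grid(n_h, batch, n_kv):
    """Enumerate CTAs for grid dim3(1, n_h, batch): one CTA per (head, batch)."""
    return [
        {"blockIdx": (0, h, b), "q_head": h, "kv_head": kv_head_for(h, n_h, n_kv), "batch": b}
        for b in range(batch)
        for h in range(n_h)
    ]
```

MQA is the n_kv = 1 case, where every query head maps to KV head 0.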
This is the most-leverage move, not the most-exciting one. The kernel works in isolation for the exact shape that dominates decode (T=1, single KV tile, multi-head, HD∈{64,128,256}). Wire it. Stop forking variants.
Definition of done:
- production.py:_run_fmha_segmented has a fast path: if T == 1 and n_segments == 1, call the 6-warp kernel via a torch custom op. Single launch. No torch.cuda.synchronize() on the hot path. No per-tile allocations.
- Numerical parity gate: cos ≥ 0.999998 against the current Python-merge path on the cases it handles. Not "looks fine"; a bitwise-diffable test.
- Launch count measurement before/after on one decoded V4-Pro CSA layer. Record with Nsight Systems or a cudaLaunchKernel counter. Target on the new path: 1 kernel launch, 0 cudaDeviceSynchronize calls on the hot path. This is the number that proves P2 progress; cosine alone does not.
- The slow path (multi-segment, prefill T>1) stays as-is for now. Don't touch it until the decode path is integrated and measured.
Failure modes to watch for (call them out if you see them):
- Agent creates an 8th .cuh variant. The answer is "integrate, don't fork."
- Agent regresses the cosine to "make integration easier." Parity is the gate.
- Agent removes the Python fallback before the fast path covers all shapes used.
P4 — Resolve the TMA hang ✅ RESOLVED
Root cause: GMEM pointer misalignment. TMA requires 128-byte aligned GMEM addresses. With proper alignment, ALL descriptor configs work (no swizzle, swizzle_128B, OOB_FILL_ZERO). The bit-21 workaround was NOT needed. See docs/p4_tma_hang_resolution.md.
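The alignment rule from the P4 resolution can be checked host-side before a descriptor is encoded. The same check also covers the byte-stride point from the CUDA 13 notes.

This is a sketch, and its constants come from different sources:

- The 128-byte address constant follows this doc's resolution note. The CUDA driver API reference documents a weaker 16-byte minimum for globalAddress (32 bytes with 32B interleave), so treat 128 as the empirical requirement recorded here, not as the API's.
- The 16-byte stride multiple is the documented cuTensorMapEncodeTiled requirement.

The function names are hypothetical. The address would come from something like a tensor's data_ptr().

```python
TMA_GMEM_ALIGN = 128   # from docs/p4_tma_hang_resolution.md (observed on B200)
TMA_STRIDE_ALIGN = 16  # cuTensorMapEncodeTiled: each globalStrides entry is a multiple of 16 bytes

def byte_strides(shape, elem_bytes):
    """Row-major shape -> (globalDim innermost-first, globalStrides in BYTES).

    The API takes rank - 1 strides: dimension 0 is implicitly contiguous.
    """
    dims = list(reversed(shape))
    strides = []
    running = elem_bytes
    for d in dims[:-1]:
        running *= d
        strides.append(running)
    return dims, strides

def tma_precheck(address, shape, elem_bytes):
    """List the alignment problems found; an empty list means these checks pass."""
    problems = []
    if address % TMA_GMEM_ALIGN:
        problems.append(f"global address {address:#x} is not {TMA_GMEM_ALIGN}-byte aligned")
    _, strides = byte_strides(shape, elem_bytes)
    for i, s in enumerate(strides):
        if s % TMA_STRIDE_ALIGN:
            problems.append(f"globalStrides[{i}] = {s} bytes is not a multiple of {TMA_STRIDE_ALIGN}")
    return problems
```

For example, a [4, 64, 128] bf16 tensor gives globalDim [128, 64, 4] and globalStrides [256, 16384] bytes. Passing element strides instead would give [128, 8192].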
P4 — Resolve the TMA hang with print-and-diff, NOT another guess
Memory note ends with "root cause unknown" and three speculative options. That's the prompt to apply doctrine rule 3, not to pick one.
Definition of done:
- Dump the descriptor bytes from a working CuTeDSL TMA path (the existing CuTeDSL FMHA already runs TMA correctly; that's the oracle).
- Dump the descriptor bytes from the raw-CUDA cuTensorMapEncodeTiled path that hangs. memcmp them. Document the byte-level differences in a .md paper trail alongside cuda13_tma_notes.md. The hang is almost certainly in one of the four enum fields the API takes after the strides: interleave layout, swizzle mode, L2 promotion, or OOB fill. Print every field of both descriptors side by side.
- Once they match: re-run, confirm no hang. If they don't match: code to the diff, not to a new theory.
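Here is a sketch of the memcmp step, written in Python so the result drops straight into the paper trail. It assumes both descriptors have already been dumped as raw 128-byte blobs (sizeof(CUtensorMap) is 128 bytes). How to extract the CuTeDSL descriptor is not shown, and the function names are hypothetical.

The CUtensorMap layout is opaque, so a byte diff localizes differences but does not name the field. One way to map an offset back to an argument is to re-encode with a single argument changed at a time.

```python
CU_TENSOR_MAP_BYTES = 128  # sizeof(CUtensorMap): an opaque 128-byte blob

def diff_descriptors(working: bytes, hanging: bytes):
    """Byte-level diff of two tensor-map dumps: [(offset, working_byte, hanging_byte)]."""
    if len(working) != CU_TENSOR_MAP_BYTES or len(hanging) != CU_TENSOR_MAP_BYTES:
        raise ValueError("expected two 128-byte CUtensorMap dumps")
    return [(i, a, b) for i, (a, b) in enumerate(zip(working, hanging)) if a != b]

def diff_report(working: bytes, hanging: bytes) -> str:
    """Markdown rows for the .md paper trail next to cuda13_tma_notes.md."""
    rows = ["| offset | working (CuTeDSL) | hanging (raw CUDA) |", "|---|---|---|"]
    for off, a, b in diff_descriptors(working, hanging):
        rows.append(f"| {off:3d} | 0x{a:02x} | 0x{b:02x} |")
    return "\n".join(rows)
```

An empty diff with the hang still present is itself a result. It points away from the descriptor and toward the pointer, the mbarrier, or the SMEM destination.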
Failure modes to watch for:
- Agent reaches for "manually construct TMA descriptor bytes" (option B from the memory note) without doing the diff first. That's a bigger guess, not a smaller one.
- Agent declares this "deferred" and moves on. TMA is on the critical path for prefill / long-context throughput; "decode works without TMA" is true but doesn't generalize.
P5 — In-kernel online softmax across KV tiles (the real P2)
After P3, the decode fast path bypasses the Python merge for n_segments==1.
But CSA with top_k=1024 always has n_segments=8, so the moment top_k > 128
we're back on the slow Python path. This priority extends the 6-warp kernel to
loop KV tiles internally with FlashAttention-2 running max/sum.
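The segment arithmetic behind that claim, as a sketch. It makes two assumptions:

- The 128-entry tile size is inferred from top_k=1024 giving n_segments=8.
- The routing predicate restates the P3 fast-path conditions.

The function names are hypothetical, not production.py's.

```python
KV_TILE = 128  # assumption: per-segment KV tile implied by top_k=1024 -> n_segments=8
FAST_PATH_HD = {64, 128, 256}  # head dims the P3 fast path accepts

def n_segments(top_k, kv_tile=KV_TILE):
    """Number of KV segments for a top-k selection: ceil(top_k / kv_tile)."""
    return -(-top_k // kv_tile)

def takes_p3_fast_path(T, top_k, hd):
    """P3 routing: decode (T == 1), a single KV segment, and a supported head dim."""
    return T == 1 and n_segments(top_k) == 1 and hd in FAST_PATH_HD
```

Under that tile size, every top_k above 128 fails the n_segments == 1 check. That is why the P3 fast path alone doesn't cover CSA.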
Definition of done:
- The 6-warp kernel (the same file, not a new variant) accepts a kv_tile_count parameter and loops over it internally, maintaining (row_max, row_sum) rescale across tiles. Standard FA2 shape.
- Single launch handles any n_segments. The Python merge in production.py for multi-tile is deleted, not "kept as fallback."
- Parity gate: cos ≥ 0.999998 vs the now-deleted Python merge, captured in a regression test before deletion.
- Launch count on the V4-Pro CSA layer: still 1, regardless of top_k. The cudaDeviceSynchronize calls go to zero, period.
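The P5 gates reduce to three numbers. This is a minimal sketch of the check under three assumptions:

- The reference outputs were saved from the Python merge before it was deleted.
- Launch and sync counts were measured separately, e.g. from an Nsight Systems trace.
- COS_GATE and p5_gates are hypothetical names.

```python
import math

COS_GATE = 0.999998

def cosine(a, b):
    """Cosine similarity of two flattened outputs (kernel vs. saved reference)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def p5_gates(kernel_out, reference_out, launches_per_layer, device_syncs):
    """Falsifiable P5 gates: parity number, launch count, sync count. Returns (ok, details)."""
    if len(kernel_out) != len(reference_out):
        raise ValueError("outputs must be the same length")
    cos = cosine(kernel_out, reference_out)
    details = {"cos": cos, "launches": launches_per_layer, "syncs": device_syncs}
    ok = cos >= COS_GATE and launches_per_layer == 1 and device_syncs == 0
    return ok, details
```

Returning the numbers alongside the verdict means a status entry can record them. A bare "DONE" is not enough.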
Failure modes to watch for:
- Agent writes an fmha_6warp_multikvtile.cuh instead of extending the chosen file.
- Agent moves the rescale to host-side "for clarity." The rescale must live in the kernel or the launch-count guarantee breaks.
P6 — One-way TMEM→regs→SMEM→GMEM epilogue, with FP4 hook
Unlocks NVFP4-1.2 (FP4 output fusion) and multi-CTA grids. Pattern already runs
correctly in dsv4/kernels/gemm/dense.py; the symbols are already imported into
fmha.py (lines 71–72) and unused. The work is wiring, not invention.
Definition of done:
- The 6-warp kernel's epilogue uses epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition + TMA store, with an epilogue_op lambda slot (constexpr) for fusion. Same shape as MoE GEMM. epilogue_op = lambda x: x ships as the default. The FP4 pack lambda lands as a flag, off by default, behind its own test.
- Multi-CTA grid smoke test: M ≥ 256 prefill, flat_divide coords accepted, no crash, correct numerics.
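This Python model of the epilogue_op slot shows the shape of the contract: identity by default, FP4 rounding only when explicitly passed. It is a sketch, not CuTeDSL:

- The stage comments stand in for the TMEM, register, SMEM, and TMA steps.
- fp4_round is a simplified, hypothetical stand-in for the real FP4 pack. It rounds to nearest onto E2M1 values, with ties to the even mantissa and saturation at ±6.
- It applies no block scale factors and does no nibble packing.

```python
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # index == 3-bit E2M1 code

def identity(x):
    return x

def fp4_round(x):
    """Round onto E2M1 values: nearest, ties to even mantissa (code bit 0), saturate at +/-6."""
    mag = min(abs(x), 6.0)
    best = min(range(len(E2M1_VALUES)), key=lambda i: (abs(E2M1_VALUES[i] - mag), i & 1))
    return -E2M1_VALUES[best] if x < 0 else E2M1_VALUES[best]

def epilogue(acc_tile, epilogue_op=identity):
    """One-way epilogue model; epilogue_op is the fusion slot (a constexpr lambda in the kernel)."""
    regs = list(acc_tile)                  # stands in for the TMEM -> register copy
    regs = [epilogue_op(v) for v in regs]  # fusion point; identity unless a flag selects FP4
    smem = list(regs)                      # stands in for register -> SMEM staging
    return smem                            # stands in for the TMA store to GMEM
```

Keeping the op as a parameter is what lets the FP4 flag land second, behind its own test, without touching the epilogue path again.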
Failure modes to watch for:
- Agent re-introduces epilogue_tma_store "because it's simpler." That's the blocker we're removing; don't re-add it.
- Agent enables FP4 pack at the same time as the epilogue rewrite. That's two changes behind one test. Land the epilogue first, FP4 fusion second.
P7 — Multi-row softmax T>32, by printing the TMEM column layout
The agent's plan ("use 16x256b.x1") is a guess. It may be right; it may not be. Print the layout before changing the instruction.
Definition of done:
- Print the TMEM column map for the HD=256, T=128 case: for each (warp, lane, TMEM column), which (row, col) of S does it own? Write the observed map into a .md doc.
- Pick the TMEM load instruction that matches the observed map. If it's 16x256b.x1, fine, but with the table backing the choice.
- Parity gate: cos ≥ 0.999998 for T∈{1, 32, 64, 128}, all in the same kernel.
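One way to get the observed map without assuming it takes three steps:

1. Before the load, write a unique tag into every S[row][col]. The tag row * n_cols + col is exact in fp32 at this size.
2. Have each thread copy out what it read, along with its warp, its lane, and the TMEM column it addressed.
3. Decode the tags on the host.

The sketch below covers only the host-side decode and the .md table. The tuple format and function names are hypothetical, and the test feeds it synthetic observations, not real ones.

```python
def tag(row, col, n_cols):
    """Value to write into S[row][col] before the load, so each read identifies its source."""
    return row * n_cols + col  # exact in fp32 while the tag stays below 2**24

def decode_observations(observations, n_cols):
    """observations: (warp, lane, tmem_col, value) tuples copied out of the kernel."""
    table = []
    for warp, lane, tmem_col, value in observations:
        v = int(value)
        if v != value:
            raise ValueError(f"non-integer tag {value}: the load returned something other than a tag")
        row, col = divmod(v, n_cols)
        table.append((warp, lane, tmem_col, row, col))
    return sorted(table)

def to_markdown(table):
    """Observed (warp, lane, TMEM column) -> (row, col) map as a .md table."""
    lines = ["| warp | lane | tmem col | S row | S col |", "|---|---|---|---|---|"]
    lines += [f"| {w} | {lane} | {c} | {r} | {col} |" for w, lane, c, r, col in table]
    return "\n".join(lines)
```

The instruction choice then follows from the table rather than the other way around.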
Failure modes to watch for:
- Agent picks the instruction first, then "interprets the layout to match." Layout first, instruction second.
P8 — Consolidate: delete 6 of the 7 6-warp variants
After P3–P7, exactly one variant should exist. The other six are landmines for the next agent (and for you when you context-switch back in three weeks).
Definition of done: ls dsv4/kernels/attention/fmha_6warp*.cuh returns one
file. Tests updated to point at it. git rm for the rest. No "archive/" folder.
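The P8 gate can be written as a script, so it is a binary check rather than a reading of ls output. The repo-root argument and the function names are assumptions; the glob is the one given above.

```python
import glob
import os

def remaining_variants(repo_root):
    """Files matching the P8 glob, sorted for stable output."""
    pattern = os.path.join(repo_root, "dsv4", "kernels", "attention", "fmha_6warp*.cuh")
    return sorted(glob.glob(pattern))

def p8_gate(repo_root):
    """Passes only when exactly one fmha_6warp*.cuh file remains."""
    files = remaining_variants(repo_root)
    return len(files) == 1, files
```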
What is not on this list, and why
- MoE / GEMM stack: already production-grade per audit. Don't touch.
- mHC implementation: correct per paper, working.
- CSA compressor (overlapped 2m): correct per paper, working.
- Router (sqrt(softplus) + aux-loss-free): correct per paper, working.
- MegaMoE EP overlap (dispatch/combine waves): correctly scoped out. That's a multi-node EP concern, not a single-GPU kernel concern.
- Reasoning effort modes, Quick Instruction, OPD: post-training surface, not kernel surface.
DOCTRINE — every priority above is gated on these, no exceptions
1. CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python. The Python KV merge in production.py is the cautionary tale. The 6-warp raw-CUDA kernel is the correct fallback for that wall, but it has to be integrated, not exhibited.
2. Raw CUDA ≠ scalar math. Every priority above keeps tcgen05 / UMMA / TMEM / TMA / warp-level reductions. No scalar dot products as a "temporary simplification." That's how we got the indexer LUT bug: a "temporary" scalar oracle that was wrong and trusted as a reference.
3. Print, don't guess. Code to the data. Two specific applications this round:
   - P4 (TMA hang): dump descriptor bytes, diff, code to the diff. Do not guess at "format mismatch."
   - P7 (TMEM layout): print the observed (warp, lane) → (row, col) map, pick the instruction from the map. Do not pick the instruction first.
4. Integration over exploration. Seven .cuh variants is the diagnostic. The priority order above is deliberately "wire the one that works" before "make it handle more shapes." A working kernel not on the production path is worth zero. Resist the urge to fork a new file for each new concern; extend the chosen file or step back to plan.
5. Falsifiable gates only. Every "definition of done" above has a number (cosine, launch count, file count) or a binary check. "Looks fine," "milestone complete," and "in progress" are not gates. If a status doc says "DONE" without a number next to it, the doctrine reads it as "NOT DONE."