- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 at the boundary, uint16_t bf16_t inside the kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast-path dispatch for T=1, n_segments==1, hd in {64,128,256}.
  Falls through to the CuTeDSL slow path for multi-segment/prefill.
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
- Grid: dim3(1, n_h, batch_size), one CTA per (head, batch)
- MQA: k_head_stride=0, so all Q heads share the same K/V
- Single kernel launch, zero cudaDeviceSynchronize on the hot path
- Normalized output for single-segment decode
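The stride-0 trick for MQA can be illustrated with a small host-side model of the grid. This is a sketch only: it assumes the y dimension of dim3(1, n_h, batch_size) indexes the head and z indexes the batch, and `k_batch_stride` plus all the sizes are hypothetical names and values, not taken from the kernel.

```python
def kv_block_offset(head, batch, k_head_stride, k_batch_stride):
    """Element offset of the K (or V) block read by the CTA for (head, batch)."""
    return batch * k_batch_stride + head * k_head_stride


def grid_offsets(n_h, batch_size, k_head_stride, k_batch_stride):
    # Assumed mapping for dim3(1, n_h, batch_size): y -> head, z -> batch.
    return {
        (head, batch): kv_block_offset(head, batch, k_head_stride, k_batch_stride)
        for batch in range(batch_size)
        for head in range(n_h)
    }


# MHA-like layout: every head owns its own K/V block (hypothetical strides).
mha = grid_offsets(n_h=4, batch_size=2, k_head_stride=256, k_batch_stride=1024)

# MQA: k_head_stride = 0 collapses every head of a batch onto one K/V block.
mqa = grid_offsets(n_h=4, batch_size=2, k_head_stride=0, k_batch_stride=256)
```

With a zero head stride the number of CTAs stays the same; only the addresses they read coincide, which is why no separate MQA kernel is needed.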
ISSUE — Priority 2: kill the Python KV merge; one-way in-kernel epilogue + online softmax
Status: OPEN — structural shortcut on the FMHA decode hot path.
Severity: HIGH for "production grade." Correctness is fine (cos 0.999998); the
problem is this path can never be fast, and it blocks the entire multi-CTA / FP4
fusion chain (D2/P4, NVFP4-1.2, NVFP4-2).
Scope: dsv4/kernels/attention/production.py (_run_fmha_segmented),
dsv4/kernels/attention/fmha.py (epilogue), the WIP fmha_smem_acc.py, and the
raw-CUDA fmha_sm100_tc.cuh.
TL;DR
Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL
kernel once per (segment × pv_tile), copying un-normalized O + LSE back to the
host-side allocator, and merging tiles with eager PyTorch ops — with a full
torch.cuda.synchronize() inside the inner loop. That is orchestration, not a
kernel. The fix is the standard FlashAttention shape: loop KV tiles inside the
kernel with running max/sum (online softmax), and write the result once through a
one-way TMEM→regs→SMEM→GMEM epilogue — the exact pattern the MoE GEMM already
runs correctly on Blackwell.
This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not a license to fall back to Python.
The shortcut, measured (not estimated)
dsv4/kernels/attention/production.py, _run_fmha_segmented:
    for seg in range(n_segments):
        ...
        k_seg = torch.cat([k_seg, torch.zeros(...)])            # alloc + copy to pad
        seg_o = torch.zeros(M, hd, ...)                         # per-seg allocs
        for nt in range(n_pv_tiles):
            c_tile = torch.zeros(M, pv_n_tile, 1, ...)          # per-tile allocs
            lse_tensor = torch.zeros(M, 1, 1, ...)
            mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...)  # descriptor rebuild
            ...                                                 # every iteration
            compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...)     # 1 launch
            torch.cuda.synchronize()                            # <-- FULL DEVICE SYNC
            seg_o[:, v_start:v_end] = c_tile[...].float()
        # eager exp/log/div merge (≈5 launches)
        o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum
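For reference, the merge in the last line is the usual log-sum-exp combination of two normalized partial outputs. The excerpt does not show how e_old, e_new and e_sum are computed, so the definitions below are an assumption (the standard form), with ℓ_old and ℓ_new standing for the LSE of the accumulated result and of the new segment:

```latex
\begin{aligned}
m &= \max(\ell_{\text{old}}, \ell_{\text{new}}), \\
e_{\text{old}} &= \exp(\ell_{\text{old}} - m), \qquad
e_{\text{new}} = \exp(\ell_{\text{new}} - m), \qquad
e_{\text{sum}} = e_{\text{old}} + e_{\text{new}}, \\
O_{\text{accum}} &\leftarrow \frac{e_{\text{old}}\, O_{\text{accum}} + e_{\text{new}}\, O_{\text{seg}}}{e_{\text{sum}}}, \qquad
\ell_{\text{old}} \leftarrow m + \log e_{\text{sum}}.
\end{aligned}
```

Each of these steps is a separate eager op on the host-driven path, which is where the ≈5 launches per segment come from.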
Per-call cost on the hot path, every iteration:
- 1 kernel launch + 1 full cudaDeviceSynchronize (serializes the stream: no overlap, no pipelining, no async)
- 6+ torch.zeros allocations (allocator churn)
- torch.cat to pad the partial tail segment (alloc + copy)
- ct.from_dlpack(...).mark_layout_dynamic(...) ×6 (host-side descriptor rebuild)
- ~5 eager merge ops per segment
Launch/sync count, V4-Pro CSA decode, single token, single layer:
top_k = 1024 → n_segments = 1024/128 = 8; hd = 512 → n_pv_tiles = 512/128 = 4.
→ 32 kernel launches + 32 device syncs per CSA layer, plus ~40 eager merge ops.
Across the interleaved CSA/HCA layers of a 61-layer model that is ~1–2k launches
and ~1–2k device syncs per decoded token. At single-digit-µs launch+sync each,
that's milliseconds of pure overhead per token — a hard ceiling of tens of tok/s
before any real compute runs. At 1M context (KV-read bound, the entire point of
V4) this is the bottleneck, full stop.
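The launch math above can be reproduced with a few lines of Python. This is a back-of-the-envelope sketch: the 128-wide segment and PV tile come from the divisions shown above, while the 5 µs per launch+sync and the "every layer costs as much as a CSA layer" upper bound are assumptions chosen for illustration, not measurements.

```python
def ceil_div(a, b):
    return -(-a // b)


def launches_per_layer(top_k, hd, seg_len=128, pv_n_tile=128):
    """Kernel launches (and device syncs) for one layer on the segmented path."""
    n_segments = ceil_div(top_k, seg_len)
    n_pv_tiles = ceil_div(hd, pv_n_tile)
    return n_segments * n_pv_tiles


def overhead_ms(n_layers, per_layer, us_per_launch_and_sync):
    """Pure launch+sync overhead per decoded token, in milliseconds."""
    return n_layers * per_layer * us_per_launch_and_sync / 1000.0


per_layer = launches_per_layer(top_k=1024, hd=512)  # 8 segments x 4 PV tiles
upper_bound_launches = 61 * per_layer               # assumes all 61 layers behave like CSA
example_ms = overhead_ms(61, per_layer, 5.0)        # 5 us is an assumed cost
```

Under those assumptions the upper bound sits at the top of the ~1–2k range quoted above, and the overhead lands in the milliseconds per token before any attention math runs.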
Head-packing (MQA → all 128 Q heads in M) is already correct and keeps this from also multiplying by heads — good. The segment/tile loop and the per-iteration sync are the problem.
The blocker
dsv4/kernels/attention/fmha.py (~line 620) writes O via
utils.gemm.sm100.epilogue_tma_store, which reads O straight from TMEM and
TMA-stores it. Per the file header (point 2): it cannot accept flat_divide-based
GMEM coordinates, so you can't fold (batch, head, kv_tile) into a multi-CTA grid.
That single epilogue choice is why multi-tile has to be done host-side, and why
D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it.
The TMEM round-trip ("D1.5 fundamentally broken") is a CuTeDSL atom-pairing limitation, not hardware — CUTLASS C++ pairs Ld/St atoms that work. So the answer is not "merge in Python," it's "use the one-way epilogue (no round-trip) or write the raw CUDA kernel."
The pattern to port FROM (already working on Blackwell)
dsv4/kernels/gemm/dense.py runs the correct one-way epilogue on every MoE GEMM:
TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for
fusion. Reuse these symbols (already imported into fmha.py):
- sm100_utils.compute_epilogue_tile_shape (line ~313): epi subtile
- sm100_utils.make_smem_layout_epi (line ~361): SMEM staging layout
- epilogue_tmem_copy_and_partition / epilogue_smem_copy_and_partition (imported at fmha.py lines 71–72, currently unused on the O path)
- epilogue_op: cutlass.Constexpr = lambda x: x (line ~405): the hook where FP4 pack lands later (NVFP4-1.2)
This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the MoE kernel already proves out.
Design of the fix
1. Online softmax across KV tiles, in-kernel. Replace the host-side log-sum-exp
merge with a running (row_max, row_sum) rescale inside the KV-tile loop — standard
FlashAttention-2. The kernel already emits un-normalized O + LSE + row_sums for a
single tile; lift the accumulation into the tile loop so one launch consumes all
selected KV. No host merge, no per-tile sync, no per-tile allocs.
2. One-way epilogue. After the tile loop, O lives in TMEM. Copy
TMEM→regs→SMEM→GMEM via the MoE epilogue path. No epilogue_tma_store, no TMEM
round-trip. Normalize by the final row_sum in registers before the SMEM store
(or keep emitting LSE if external normalization is still wanted — but it shouldn't
be needed once the merge is in-kernel).
3. Multi-CTA grid. With the one-way epilogue accepting flat_divide coords,
fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. use_2cta_instrs
applies for prefill/batched shapes (M ≥ 256); decode single-token stays 1-CTA.
4. Preallocated scratch. Whatever staging remains is allocated once at warmup
and reused (cudagraph-safe), never torch.zeros on the hot path.
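The running (row_max, row_sum) rescale from step 1 can be checked on the host before touching the kernel. The sketch below is plain Python for a single query row, not the kernel's code: the function names and the tile size are hypothetical, and it only demonstrates that looping over KV tiles with a rescaled accumulator reproduces a one-shot softmax.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_reference(q, keys, values):
    """One-shot softmax(q . K^T) @ V for a single query row."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z
            for d in range(len(values[0]))]


def attention_online(q, keys, values, tile=2):
    """Same result, consuming KV one tile at a time with running max/sum."""
    row_max = -math.inf
    row_sum = 0.0
    acc = [0.0] * len(values[0])          # un-normalized O accumulator
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [dot(q, k) for k in k_tile]
        new_max = max(row_max, max(scores))
        scale = math.exp(row_max - new_max)   # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]
        row_sum = row_sum * scale + sum(p)
        acc = [a * scale + sum(pi * v[d] for pi, v in zip(p, v_tile))
               for d, a in enumerate(acc)]
        row_max = new_max
    return [a / row_sum for a in acc]         # normalize once, at the end


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.5, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = attention_reference(q, keys, values)
out = attention_online(q, keys, values, tile=2)
```

The single division at the end corresponds to normalizing by the final row_sum in registers before the store in step 2; nothing leaves the loop until then.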
Two ramps — pick per doctrine, not per convenience
- Ramp A (preferred if CuTeDSL cooperates): finish the one-way epilogue in fmha_smem_acc.py against the MoE pattern. Status of that file is unclear ("many commits, unclear if working"), so the first action is to determine whether it works, with a print-and-diff, not by reading the diff log.
- Ramp B (when CuTeDSL walls, the legitimate fallback): fmha_sm100_tc.cuh is already the right shape: tcgen05.mma SS for QK, TS for PV, TMEM accumulators, UMMA descriptors, and it already carries sRowMax/sRowSum running state for in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++, it is Blackwell-native, and it has no epilogue_tma_store constraint because you control the store.
Test plan — measure, don't eyeball
- Correctness parity: new in-kernel path vs current Python-merge path on the same inputs across the matrix that already passes — hd ∈ {64,128,256,512}, n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998 (must match the path it replaces, this is a refactor not a numerics change).
- Launch + sync count: Nsight Systems (or a CUPTI launch counter) on one decoded token, V4-Pro CSA layer. Record launches and cudaDeviceSynchronize calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from 32 → 0 on the hot path.
- Latency: per-token decode latency at context 8k / 128k / 1M, before/after. This is the number that says whether Priority 1's "does the overhead matter" question is answered. (Spoiler from the launch math: it matters.)
- Unblock check: confirm the new epilogue accepts a multi-CTA grid (D2 smoke test, M≥256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2 stub: epilogue_op that amax+packs, behind a flag, off by default).
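The correctness gate in the first bullet reduces to one cosine comparison per test case. A minimal sketch in plain Python follows, assuming both outputs have been flattened to lists of floats; `cosine_similarity` and `passes_gate` are hypothetical helper names, and the real test presumably works on tensors.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine is undefined for a zero vector")
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def passes_gate(candidate, reference, threshold=0.999998):
    """True when the new path matches the path it replaces closely enough."""
    return cosine_similarity(candidate, reference) >= threshold


same = passes_gate([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
drifted = passes_gate([1.0, 2.0, 3.1], [1.0, 2.0, 3.0])
```

Cosine is scale-invariant, so it will not catch an output that is off by a constant factor (for example a missing row_sum normalization); pairing it with a max-absolute-difference check is a reasonable extra guard.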
DOCTRINE REMINDER (this issue is the cautionary tale)
- CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python. The Python KV merge is what rule 1 exists to prevent. fmha_sm100_tc.cuh is the correct fallback. A host-side loop with a per-iteration cudaDeviceSynchronize is never the production answer.
- Raw CUDA ≠ scalar math. The fallback kernel stays tcgen05/UMMA/TMEM/TMA with warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar FMA to get it compiling.
- Print, don't guess. Before touching fmha_smem_acc.py, print its actual output and diff against the reference; do not infer its status from commit messages. When wiring the epilogue, print epi tile shape, SMEM layout, TMEM offsets, and the MMA instruction shape at construction, and code to those values. The epilogue is exactly the kind of layout surface where a guess passes a toy test and corrupts at real shapes.