# ISSUE (Priority 2): kill the Python KV merge; one-way in-kernel epilogue + online softmax

**Status:** OPEN. Structural shortcut on the FMHA decode hot path.

**Severity:** HIGH for "production grade." Correctness is fine (cos 0.999998); the problem is that this path can never be fast, and it blocks the entire multi-CTA / FP4 fusion chain (D2/P4, NVFP4-1.2, NVFP4-2).

**Scope:** `dsv4/kernels/attention/production.py` (`_run_fmha_segmented`), `dsv4/kernels/attention/fmha.py` (epilogue), the WIP `fmha_smem_acc.py`, and the raw-CUDA `fmha_sm100_tc.cuh`.

---

## TL;DR

Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL kernel once per (segment × pv_tile), copying un-normalized O + LSE back into tensors from the host-side allocator, and merging tiles with **eager PyTorch ops**, with a full **`torch.cuda.synchronize()` inside the inner loop**. That is orchestration, not a kernel.

The fix is the standard FlashAttention shape: **loop over KV tiles inside the kernel with a running max/sum (online softmax), and write the result once through a one-way TMEM→regs→SMEM→GMEM epilogue**, the exact epilogue pattern the MoE GEMM already runs correctly on Blackwell.

This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not a license to fall back to Python.

---

## The shortcut, measured (not estimated)

`dsv4/kernels/attention/production.py`, `_run_fmha_segmented`:

```python
for seg in range(n_segments):
    ...
    k_seg = torch.cat([k_seg, torch.zeros(...)])            # alloc + copy to pad
    seg_o = torch.zeros(M, hd, ...)                         # per-seg allocs
    for nt in range(n_pv_tiles):
        c_tile = torch.zeros(M, pv_n_tile, 1, ...)          # per-tile allocs
        lse_tensor = torch.zeros(M, 1, 1, ...)
        mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...)  # descriptor rebuild
        ...                                                 # every iteration
        compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...)     # 1 launch
        torch.cuda.synchronize()                            # <-- FULL DEVICE SYNC
        seg_o[:, v_start:v_end] = c_tile[...].float()
    # eager exp/log/div merge (≈5 launches)
    o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum
```

Per-call cost on the hot path, every iteration:

- 1 kernel launch **+ 1 full `cudaDeviceSynchronize`** (the host stalls until the whole device is idle: no overlap, no pipelining, no async)
- 6+ `torch.zeros` allocations (allocator churn)
- `torch.cat` to pad the partial tail segment (alloc + copy)
- `ct.from_dlpack(...).mark_layout_dynamic(...)` ×6 (host-side descriptor rebuild)
- ~5 eager merge ops per segment

**Launch/sync count, V4-Pro CSA decode, single token, single layer:**
`top_k = 1024` → `n_segments = 1024/128 = 8`; `hd = 512` → `n_pv_tiles = 512/128 = 4`. That is 8 × 4 = **32 kernel launches + 32 device syncs** per CSA layer, plus ~40 eager merge ops (8 segments × ~5).

Across the interleaved CSA/HCA layers of a 61-layer model, that comes to **~1–2k launches and ~1–2k device syncs per decoded token.** At single-digit µs per launch+sync, that is milliseconds of pure overhead per token, and the per-iteration allocations, descriptor rebuilds, and eager merge ops add more on top. The result is a hard ceiling of tens of tok/s before any real compute runs. At 1M context (KV-read bound, the entire point of V4) this is the bottleneck, full stop.

Head-packing (MQA → all 128 Q heads in M) is already correct and keeps the launch count from also multiplying by heads; that part is good. The segment/tile loop and the per-iteration sync are the problem.

---

## The blocker

`dsv4/kernels/attention/fmha.py` (~line 620) writes O via `utils.gemm.sm100.epilogue_tma_store`, which reads O straight from TMEM and TMA-stores it. Per the file header (point 2), it **cannot accept flat_divide-based GMEM coordinates**, so you can't fold (batch, head, kv_tile) into a multi-CTA grid. That single epilogue choice is why multi-tile has to be done host-side, and why D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it.

The TMEM round-trip ("D1.5 fundamentally broken") is a **CuTeDSL** atom-pairing limitation, not a hardware one: CUTLASS C++ pairs Ld/St atoms that work.

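
Because the epilogue forces multi-tile work onto the host, the cost scales with the tile grid rather than with the actual work. The calculator below is a minimal sketch that reproduces the launch math from the shortcut section, making the before/after targets of test 2 explicit. The helper names, the dataclass, and the default tile sizes are illustrative assumptions, not code from the repo. The 128-row segment and 128-column PV tile are inferred from the `1024/128` and `512/128` divisions above, and the 5 merge ops per segment is the issue's own estimate.

```python
from dataclasses import dataclass


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


@dataclass(frozen=True)
class DecodeShape:
    top_k: int                  # selected KV entries per token (CSA)
    head_dim: int               # hd
    kv_seg: int = 128           # KV rows per segment (assumed, from 1024/128 = 8)
    pv_n_tile: int = 128        # O columns per PV tile (assumed, from 512/128 = 4)
    merge_ops_per_seg: int = 5  # the issue's ~5 eager merge ops per segment


def host_merge_counts(s: DecodeShape) -> dict:
    """Current _run_fmha_segmented shape: one launch + one device sync per (segment, pv_tile)."""
    n_segments = ceil_div(s.top_k, s.kv_seg)  # a padded tail segment still costs full launches
    n_pv_tiles = ceil_div(s.head_dim, s.pv_n_tile)
    launches = n_segments * n_pv_tiles
    return {
        "launches": launches,
        "syncs": launches,
        "eager_merge_ops": s.merge_ops_per_seg * n_segments,
    }


def in_kernel_counts() -> dict:
    """Target after the fix: one launch per layer, no host sync on the hot path."""
    return {"launches": 1, "syncs": 0, "eager_merge_ops": 0}


def per_token(counts: dict, n_layers: int) -> dict:
    """Scale per-layer counts by the number of layers that take this path."""
    return {k: v * n_layers for k, v in counts.items()}


if __name__ == "__main__":
    csa = DecodeShape(top_k=1024, head_dim=512)
    print("host merge, per layer:", host_merge_counts(csa))
    print("in-kernel,  per layer:", in_kernel_counts())
    # Illustration only: pretend all 61 layers take the same CSA path.
    print("host merge, 61 layers:", per_token(host_merge_counts(csa), 61))
```

For `DecodeShape(top_k=1024, head_dim=512)` this reports 32 launches, 32 syncs, and 40 eager merge ops per layer, compared with 1 launch and 0 syncs for the in-kernel path. Scaling to all 61 layers gives 1952 launches, which is inside the ~1–2k range quoted above. The HCA layers' real per-layer counts are not modeled here.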

So the answer is not "merge in Python"; it's "use the one-way epilogue (no round-trip) or write the raw CUDA kernel."

---

## The pattern to port FROM (already working on Blackwell)

`dsv4/kernels/gemm/dense.py` runs the correct one-way epilogue on every MoE GEMM: TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for fusion. Reuse these symbols (already imported into `fmha.py`):

- `sm100_utils.compute_epilogue_tile_shape` (line ~313): epi subtile
- `sm100_utils.make_smem_layout_epi` (line ~361): SMEM staging layout
- `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` (imported at fmha.py lines 71–72, currently unused on the O path)
- `epilogue_op: cutlass.Constexpr = lambda x: x` (line ~405): the hook where the FP4 pack lands later (NVFP4-1.2)

This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the MoE kernel already proves out.

---

## Design of the fix

**1. Online softmax across KV tiles, in-kernel.** Replace the host-side log-sum-exp merge with a running `(row_max, row_sum)` rescale inside the KV-tile loop (standard FlashAttention-2). The kernel already emits un-normalized O + LSE + row_sums for a single tile; lift the accumulation into the tile loop so one launch consumes all selected KV. No host merge, no per-tile sync, no per-tile allocs.

**2. One-way epilogue.** After the tile loop, O lives in TMEM. Copy TMEM→regs→SMEM→GMEM via the MoE epilogue path. No `epilogue_tma_store`, no TMEM round-trip. Normalize by the final `row_sum` in registers before the SMEM store (or keep emitting LSE if external normalization is still wanted, though it should not be needed once the merge is in-kernel).

**3. Multi-CTA grid.** With the one-way epilogue accepting flat_divide coords, fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. `use_2cta_instrs` applies to prefill/batched shapes (M ≥ 256); single-token decode stays 1-CTA.

**4. Preallocated scratch.** Whatever staging remains is allocated once at warmup and reused (cudagraph-safe), never `torch.zeros` on the hot path.

### Two ramps: pick per doctrine, not per convenience

- **Ramp A (preferred if CuTeDSL cooperates):** finish the one-way epilogue in `fmha_smem_acc.py` against the MoE pattern. The status of that file is unclear ("many commits, unclear if working"), so the first action is to determine whether it works with a print-and-diff, not by reading the commit log.
- **Ramp B (when CuTeDSL walls; the legitimate fallback):** `fmha_sm100_tc.cuh` is already the right shape: `tcgen05.mma` SS for QK, TS for PV, TMEM accumulators, UMMA descriptors, and it already carries `sRowMax/sRowSum` running state for in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++, it is Blackwell-native, and it has no `epilogue_tma_store` constraint because you control the store.

---

## Test plan: measure, don't eyeball

1. **Correctness parity:** new in-kernel path vs current Python-merge path on the same inputs, across the matrix that already passes: hd ∈ {64,128,256,512}, n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998 (it must match the path it replaces; this is a refactor, not a numerics change).
2. **Launch + sync count:** Nsight Systems (or a CUPTI launch counter) on one decoded token, V4-Pro CSA layer. Record launches and `cudaDeviceSynchronize` calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from 32 → 0 on the hot path.
3. **Latency:** per-token decode latency at context 8k / 128k / 1M, before/after. This is the number that answers Priority 1's "does the overhead matter" question. (Spoiler from the launch math: it matters.)
4. **Unblock check:** confirm the new epilogue accepts a multi-CTA grid (D2 smoke test, M ≥ 256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2 stub: an `epilogue_op` that amax+packs, behind a flag, off by default).


---

## DOCTRINE REMINDER (this issue is the cautionary tale)

1. **CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python.** The Python KV merge is what rule 1 exists to prevent. `fmha_sm100_tc.cuh` is the correct fallback. A host-side loop with a per-iteration `cudaDeviceSynchronize` is never the production answer.
2. **Raw CUDA ≠ scalar math.** The fallback kernel stays `tcgen05`/UMMA/TMEM/TMA with warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar FMA to get it compiling.
3. **Print, don't guess.** Before touching `fmha_smem_acc.py`, print its actual output and diff it against the reference; do not infer its status from commit messages. When wiring the epilogue, print the epi tile shape, SMEM layout, TMEM offsets, and MMA instruction shape at construction, and code to those values. The epilogue is exactly the kind of layout surface where a guess passes a toy test and corrupts at real shapes.
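
The print-and-diff step in rule 3 can be sketched as a small, framework-free helper. The names `print_and_diff` and `unravel` are hypothetical and are not existing utilities in the repo. The helper takes flattened candidate and reference outputs, for example from `fmha_smem_acc.py` and the current path after `.flatten().tolist()`. It prints the cosine similarity against the issue's 0.999998 gate, the worst element together with its multi-dimensional coordinate, and a count of non-finite values. The coordinate is the useful part for layout bugs, because errors that cluster at a tile boundary point at the epilogue rather than at the math.

```python
import math


def unravel(flat, shape):
    """Row-major flat index -> per-dimension coordinate (like numpy.unravel_index)."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat % dim)
        flat //= dim
    return tuple(reversed(coords))


def print_and_diff(name, out, ref, shape, cos_gate=0.999998):
    """Print cos similarity, worst |err| with its coordinate, and non-finite count.

    `out` and `ref` are flat float sequences in row-major order of `shape`.
    Returns True only if every output value is finite and cos >= cos_gate.
    """
    if len(out) != len(ref) or len(out) != math.prod(shape):
        raise ValueError(f"{name}: sizes differ: out={len(out)} ref={len(ref)} shape={shape}")
    nonfinite = sum(1 for a in out if not math.isfinite(a))
    # Treat NaN/Inf outputs as infinitely wrong so they win the "worst" search.
    errs = [abs(a - b) if math.isfinite(a) else math.inf for a, b in zip(out, ref)]
    worst = max(range(len(errs)), key=errs.__getitem__)
    dot = sum(a * b for a, b in zip(out, ref))
    norm = math.sqrt(sum(a * a for a in out)) * math.sqrt(sum(b * b for b in ref))
    cos = dot / norm if (nonfinite == 0 and norm > 0) else float("nan")
    ok = nonfinite == 0 and cos >= cos_gate
    print(
        f"[{name}] cos={cos:.7f} (gate {cos_gate}) "
        f"max_abs_err={errs[worst]:.3e} at {unravel(worst, shape)} "
        f"out={out[worst]!r} ref={ref[worst]!r} nonfinite={nonfinite} "
        f"-> {'PASS' if ok else 'FAIL'}"
    )
    return ok


if __name__ == "__main__":
    shape = (8, 8)  # toy (M, hd) for illustration; real runs use the production shapes
    ref = [float((i * 7) % 11) for i in range(math.prod(shape))]
    bad = list(ref)
    bad[37] = -100.0  # simulate one corrupted element
    print_and_diff("identical", ref, ref, shape)
    print_and_diff("corrupted", bad, ref, shape)
```

A single corrupted element is enough to drop the cosine far below the gate in this toy case, and the printed coordinate identifies the row and column that went wrong. A toy test that happens to land inside one tile would hide exactly that information, which is why rule 3 asks for runs at the real shapes.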