biondizzle 1e6adf5e01 P3: wire 6-warp multi-head FMHA decode fast path into production.py
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
2026-05-30 08:12:23 +00:00


ISSUE — Priority 2: kill the Python KV merge; one-way in-kernel epilogue + online softmax

Status: OPEN. Structural shortcut on the FMHA decode hot path.

Severity: HIGH for "production grade." Correctness is fine (cos 0.999998); the problem is this path can never be fast, and it blocks the entire multi-CTA / FP4 fusion chain (D2/P4, NVFP4-1.2, NVFP4-2).

Scope: dsv4/kernels/attention/production.py (_run_fmha_segmented), dsv4/kernels/attention/fmha.py (epilogue), the WIP fmha_smem_acc.py, and the raw-CUDA fmha_sm100_tc.cuh.


TL;DR

Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL kernel once per (segment × pv_tile), copying un-normalized O + LSE back to the host-side allocator, and merging tiles with eager PyTorch ops — with a full torch.cuda.synchronize() inside the inner loop. That is orchestration, not a kernel. The fix is the standard FlashAttention shape: loop KV tiles inside the kernel with running max/sum (online softmax), and write the result once through a one-way TMEM→regs→SMEM→GMEM epilogue — the exact pattern the MoE GEMM already runs correctly on Blackwell.

This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not a license to fall back to Python.


The shortcut, measured (not estimated)

dsv4/kernels/attention/production.py, _run_fmha_segmented:

for seg in range(n_segments):
    ...
    k_seg = torch.cat([k_seg, torch.zeros(...)])          # alloc + copy to pad
    seg_o   = torch.zeros(M, hd, ...)                      # per-seg allocs
    for nt in range(n_pv_tiles):
        c_tile     = torch.zeros(M, pv_n_tile, 1, ...)     # per-tile allocs
        lse_tensor = torch.zeros(M, 1, 1, ...)
        mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...) # descriptor rebuild
        ...                                                # every iteration
        compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...)    # 1 launch
        torch.cuda.synchronize()                           # <-- FULL DEVICE SYNC
        seg_o[:, v_start:v_end] = c_tile[...].float()
    # eager exp/log/div merge (≈5 launches)
    o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum

Per-call cost on the hot path, every iteration:

  • 1 kernel launch + 1 full cudaDeviceSynchronize (serializes the stream — no overlap, no pipelining, no async)
  • 6+ torch.zeros allocations (allocator churn)
  • torch.cat to pad the partial tail segment (alloc + copy)
  • ct.from_dlpack(...).mark_layout_dynamic(...) ×6 (host-side descriptor rebuild)
  • ~5 eager merge ops per segment

Launch/sync count, V4-Pro CSA decode, single token, single layer: top_k = 1024 → n_segments = 1024/128 = 8; hd = 512 → n_pv_tiles = 512/128 = 4. That gives 8 × 4 = 32 kernel launches + 32 device syncs per CSA layer, plus ~40 eager merge ops (~5 per segment × 8 segments). Across the interleaved CSA/HCA layers of a 61-layer model that is ~12k launches and ~12k device syncs per decoded token. At single-digit-µs launch+sync each, that's tens of milliseconds of pure overhead per token: a hard ceiling of tens of tok/s before any real compute runs. At 1M context (KV-read bound, the entire point of V4) this is the bottleneck, full stop.
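The launch arithmetic above can be reproduced in a few lines. This is a sketch only: the helper names are hypothetical, the 128-wide tile sizes are taken from the divisions quoted above, and the 5 µs launch+sync cost is an assumed midpoint of "single-digit µs", not a measurement.

```python
# Hypothetical helpers; tile sizes come from the 1024/128 and 512/128 figures above.
KV_TILE = 128  # keys per segment
PV_TILE = 128  # head-dim columns per PV tile


def launches_per_layer(top_k: int, hd: int) -> int:
    """One kernel launch (and one device sync) per (segment, pv_tile) pair."""
    n_segments = -(-top_k // KV_TILE)  # ceiling division
    n_pv_tiles = -(-hd // PV_TILE)
    return n_segments * n_pv_tiles


def overhead_ms(n_launches: int, us_per_launch: float) -> float:
    """Pure launch+sync overhead per token, in milliseconds."""
    return n_launches * us_per_launch / 1000.0


per_layer = launches_per_layer(top_k=1024, hd=512)
print(per_layer)  # 32

# ~12k launches per token is the figure quoted above; 5 us each is an ASSUMPTION.
ms = overhead_ms(12_000, 5.0)
print(ms, "ms per token, ceiling of about", round(1000.0 / ms, 1), "tok/s")
```

Under that assumed cost the overhead alone caps decode at roughly 17 tok/s, which is consistent with the "tens of tok/s" ceiling stated above.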

Head-packing (MQA → all 128 Q heads in M) is already correct and keeps this from also multiplying by heads — good. The segment/tile loop and the per-iteration sync are the problem.
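For reference, the eager merge at the end of the segment loop is a log-sum-exp weighted combination of per-segment results. The sketch below shows that combination for a single query row in scalar Python. The names e_old, e_new, e_sum, o_accum and seg_o_norm mirror the excerpt; everything else (function names, the max-subtraction for stability, the toy data) is an assumption for illustration, not the code in production.py.

```python
import math


def attend(scores, values):
    """softmax(scores) @ values for one query row. Returns (normalized O, LSE)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    o = [sum(wj * v[d] for wj, v in zip(w, values)) / total
         for d in range(len(values[0]))]
    return o, m + math.log(total)


def merge(o_accum, lse_accum, seg_o_norm, seg_lse):
    """Combine two normalized partial outputs using their LSE values."""
    m = max(lse_accum, seg_lse)
    e_old = math.exp(lse_accum - m)  # 0.0 while lse_accum is still -inf
    e_new = math.exp(seg_lse - m)
    e_sum = e_old + e_new
    o = [(e_old * a + e_new * b) / e_sum for a, b in zip(o_accum, seg_o_norm)]
    return o, m + math.log(e_sum)


def attend_segmented(scores, values, seg_len):
    """Per-segment attention followed by the LSE merge, as the host loop does."""
    o_accum = [0.0] * len(values[0])
    lse_accum = float("-inf")
    for start in range(0, len(scores), seg_len):
        seg_o_norm, seg_lse = attend(scores[start:start + seg_len],
                                     values[start:start + seg_len])
        o_accum, lse_accum = merge(o_accum, lse_accum, seg_o_norm, seg_lse)
    return o_accum


scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0],
          [0.5, 0.5], [-1.0, 3.0], [4.0, 2.0]]
full, _ = attend(scores, values)
merged = attend_segmented(scores, values, seg_len=2)
max_err = max(abs(a - b) for a, b in zip(full, merged))
```

The merge is mathematically exact, which is why correctness is not the complaint here. The cost is that every step of it runs as separate host-launched work.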


The blocker

dsv4/kernels/attention/fmha.py (~line 620) writes O via utils.gemm.sm100.epilogue_tma_store, which reads O straight from TMEM and TMA-stores it. Per the file header (point 2): it cannot accept flat_divide-based GMEM coordinates, so you can't fold (batch, head, kv_tile) into a multi-CTA grid. That single epilogue choice is why multi-tile has to be done host-side, and why D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it.

The TMEM round-trip ("D1.5 fundamentally broken") is a CuTeDSL atom-pairing limitation, not hardware — CUTLASS C++ pairs Ld/St atoms that work. So the answer is not "merge in Python," it's "use the one-way epilogue (no round-trip) or write the raw CUDA kernel."


The pattern to port FROM (already working on Blackwell)

dsv4/kernels/gemm/dense.py runs the correct one-way epilogue on every MoE GEMM: TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for fusion. Reuse these symbols (already imported into fmha.py):

  • sm100_utils.compute_epilogue_tile_shape (line ~313) — epi subtile
  • sm100_utils.make_smem_layout_epi (line ~361) — SMEM staging layout
  • epilogue_tmem_copy_and_partition / epilogue_smem_copy_and_partition (imported at fmha.py lines 71-72, currently unused on the O path)
  • epilogue_op: cutlass.Constexpr = lambda x: x (line ~405) — the hook where FP4 pack lands later (NVFP4-1.2)

This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the MoE kernel already proves out.


Design of the fix

1. Online softmax across KV tiles, in-kernel. Replace the host-side log-sum-exp merge with a running (row_max, row_sum) rescale inside the KV-tile loop — standard FlashAttention-2. The kernel already emits un-normalized O + LSE + row_sums for a single tile; lift the accumulation into the tile loop so one launch consumes all selected KV. No host merge, no per-tile sync, no per-tile allocs.

2. One-way epilogue. After the tile loop, O lives in TMEM. Copy TMEM→regs→SMEM→GMEM via the MoE epilogue path. No epilogue_tma_store, no TMEM round-trip. Normalize by the final row_sum in registers before the SMEM store (or keep emitting LSE if external normalization is still wanted — but it shouldn't be needed once the merge is in-kernel).

3. Multi-CTA grid. With the one-way epilogue accepting flat_divide coords, fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. use_2cta_instrs applies for prefill/batched shapes (M ≥ 256); decode single-token stays 1-CTA.

4. Preallocated scratch. Whatever staging remains is allocated once at warmup and reused (cudagraph-safe), never torch.zeros on the hot path.
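Steps 1 and 2 above amount to the following recurrence, shown here for one query row in scalar Python. This is the textbook online-softmax update, not the kernel's code: the function names and the tile argument are hypothetical, and the real kernel would keep acc in TMEM and row_max/row_sum in registers or SMEM.

```python
import math


def online_softmax_attention(scores, values, tile):
    """Single pass over KV tiles with running (row_max, row_sum) rescale."""
    dim = len(values[0])
    row_max = float("-inf")
    row_sum = 0.0
    acc = [0.0] * dim  # un-normalized O accumulator
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(row_max, max(s_tile))
        # Rescale factor for everything accumulated so far
        # (exp(-inf) == 0.0 on the first tile, so the empty state drops out).
        scale = math.exp(row_max - new_max)
        p = [math.exp(s - new_max) for s in s_tile]
        row_sum = row_sum * scale + sum(p)
        acc = [a * scale + sum(pj * v[d] for pj, v in zip(p, v_tile))
               for d, a in enumerate(acc)]
        row_max = new_max
    # Normalize once, after the tile loop (the "in registers before the
    # SMEM store" step of the one-way epilogue).
    return [a / row_sum for a in acc]


def reference_attention(scores, values):
    """Plain softmax(scores) @ values, no tiling."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * v[d] for wj, v in zip(w, values)) / total
            for d in range(len(values[0]))]


scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3, 3.0, -2.0]
values = [[float(i), float(-i) + 0.5] for i in range(8)]
tiled = online_softmax_attention(scores, values, tile=2)
ref = reference_attention(scores, values)
max_err = max(abs(a - b) for a, b in zip(tiled, ref))
```

Because the accumulator is only divided by row_sum once, after the last tile, nothing has to leave the kernel between tiles: no LSE round-trip, no intermediate output tensor, and no host merge.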

Two ramps — pick per doctrine, not per convenience

  • Ramp A (preferred if CuTeDSL cooperates): finish the one-way epilogue in fmha_smem_acc.py against the MoE pattern. Status of that file is unclear ("many commits, unclear if working"), so the first action is to determine whether it works with a print-and-diff against the reference, not by reading the commit log.
  • Ramp B (when CuTeDSL walls — the legitimate fallback): fmha_sm100_tc.cuh is already the right shape — tcgen05.mma SS for QK, TS for PV, TMEM accumulators, UMMA descriptors, and it already carries sRowMax/sRowSum running state for in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++, it is Blackwell-native, and it has no epilogue_tma_store constraint because you control the store.

Test plan — measure, don't eyeball

  1. Correctness parity: new in-kernel path vs current Python-merge path on the same inputs across the matrix that already passes: hd ∈ {64,128,256,512}, n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998 (must match the path it replaces; this is a refactor, not a numerics change).
  2. Launch + sync count: Nsight Systems (or a CUPTI launch counter) on one decoded token, V4-Pro CSA layer. Record launches and cudaDeviceSynchronize calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from 32 → 0 on the hot path.
  3. Latency: per-token decode latency at context 8k / 128k / 1M, before/after. This is the number that says whether Priority 1's "does the overhead matter" question is answered. (Spoiler from the launch math: it matters.)
  4. Unblock check: confirm the new epilogue accepts a multi-CTA grid (D2 smoke test, M≥256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2 stub: epilogue_op that amax+packs, behind a flag, off by default).
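The parity gate in step 1 can be expressed as a small check. A minimal sketch, assuming outputs are flattened to plain float lists: the helper names are hypothetical, and the real test would compare tensors from the two attention paths instead of the toy vectors used here.

```python
import itertools
import math

COS_GATE = 0.999998  # gate quoted in the test plan


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length float sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def passes_gate(new_out, ref_out, gate=COS_GATE):
    """True when the new path matches the reference path closely enough."""
    return cosine_similarity(new_out, ref_out) >= gate


# Shape matrix from the test plan: 4 head dims x 3 segment counts = 12 cases.
# The mask variants (SWA / causal / sink / n_comp) are not enumerated here.
parity_matrix = list(itertools.product((64, 128, 256, 512), (1, 2, 8)))

ref = [1.0, 2.0, 3.0, 4.0]
close = [1.0001, 2.0, 3.0, 4.0]  # tiny perturbation: should pass
far = [1.0, 2.0, 3.0, 5.0]       # visibly different: should fail
print(passes_gate(close, ref), passes_gate(far, ref), len(parity_matrix))
```

A cosine gate is insensitive to a uniform scale error, so a test like this is usually paired with an absolute or relative error check when normalization moves from the host into the kernel.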

DOCTRINE REMINDER (this issue is the cautionary tale)

  1. CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python. The Python KV merge is what rule 1 exists to prevent. fmha_sm100_tc.cuh is the correct fallback. A host-side loop with a per-iteration cudaDeviceSynchronize is never the production answer.
  2. Raw CUDA ≠ scalar math. The fallback kernel stays tcgen05/UMMA/TMEM/TMA with warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar FMA to get it compiling.
  3. Print, don't guess. Before touching fmha_smem_acc.py, print its actual output and diff against the reference — do not infer its status from commit messages. When wiring the epilogue, print epi tile shape, SMEM layout, TMEM offsets, and the MMA instruction shape at construction, and code to those values. The epilogue is exactly the kind of layout surface where a guess passes a toy test and corrupts at real shapes.