nvfp4-megamoe-kernel/archived_plans/PRIORITY2.md
biondizzle 1e6adf5e01 P3: wire 6-warp multi-head FMHA decode fast path into production.py
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
2026-05-30 08:12:23 +00:00

# ISSUE — Priority 2: kill the Python KV merge; one-way in-kernel epilogue + online softmax
**Status:** OPEN. Structural shortcut on the FMHA decode hot path.

**Severity:** HIGH for "production grade." Correctness is fine (cos 0.999998); the
problem is this path can never be fast, and it blocks the entire multi-CTA / FP4
fusion chain (D2/P4, NVFP4-1.2, NVFP4-2).

**Scope:** `dsv4/kernels/attention/production.py` (`_run_fmha_segmented`),
`dsv4/kernels/attention/fmha.py` (epilogue), the WIP `fmha_smem_acc.py`, and the
raw-CUDA `fmha_sm100_tc.cuh`.

---
## TL;DR
Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL
kernel once per (segment × pv_tile), writing un-normalized O + LSE into freshly
allocated PyTorch tensors, and merging tiles with **eager PyTorch ops**, with a full
**`torch.cuda.synchronize()` inside the inner loop**. That is orchestration, not a
kernel. The fix is the standard FlashAttention shape: **loop KV tiles inside the
kernel with running max/sum (online softmax), and write the result once through a
one-way TMEM→regs→SMEM→GMEM epilogue** — the exact pattern the MoE GEMM already
runs correctly on Blackwell.
This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not
a license to fall back to Python.

---
## The shortcut, measured (not estimated)
`dsv4/kernels/attention/production.py`, `_run_fmha_segmented`:
```python
for seg in range(n_segments):
    ...
    k_seg = torch.cat([k_seg, torch.zeros(...)])            # alloc + copy to pad
    seg_o = torch.zeros(M, hd, ...)                         # per-seg allocs
    for nt in range(n_pv_tiles):
        c_tile = torch.zeros(M, pv_n_tile, 1, ...)          # per-tile allocs
        lse_tensor = torch.zeros(M, 1, 1, ...)
        mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...)  # descriptor rebuild
        ...                                                 # every iteration
        compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...)     # 1 launch
        torch.cuda.synchronize()                            # <-- FULL DEVICE SYNC
        seg_o[:, v_start:v_end] = c_tile[...].float()
    # eager exp/log/div merge (≈5 launches)
    o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum
```
Per-call cost on the hot path, every iteration:
- 1 kernel launch **+ 1 full `cudaDeviceSynchronize`** (serializes the stream —
no overlap, no pipelining, no async)
- 6+ `torch.zeros` allocations (allocator churn)
- `torch.cat` to pad the partial tail segment (alloc + copy)
- `ct.from_dlpack(...).mark_layout_dynamic(...)` ×6 (host-side descriptor rebuild)
- ~5 eager merge ops per segment
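
For reference, the `o_accum` update at the bottom of the segment loop is the standard log-sum-exp combine of two softmax-normalized partial results. The excerpt elides how `e_old`/`e_new` are computed. The sketch below assumes the usual form, `exp(LSE - running max)`, and models a single query row in pure Python. `merge_lse` and `attend` are hypothetical helpers, not functions in `production.py`.

```python
import math


def merge_lse(o_old, lse_old, o_new, lse_new):
    """Combine two softmax-normalized partial outputs for one query row.

    o_*:   output rows, each normalized over its own KV subset.
    lse_*: log-sum-exp of the scaled scores over that subset.
    Returns (output, LSE) as if softmax had run over the union of both subsets.
    """
    m = max(lse_old, lse_new)             # subtract the max for stability
    e_old = math.exp(lse_old - m)
    e_new = math.exp(lse_new - m)
    e_sum = e_old + e_new
    o = [(e_old * a + e_new * b) / e_sum for a, b in zip(o_old, o_new)]
    return o, m + math.log(e_sum)


def attend(scores, values):
    """Reference softmax attention for one row over one KV subset: (output, LSE)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    out = [sum(wi * v[j] for wi, v in zip(w, values)) / total
           for j in range(len(values[0]))]
    return out, m + math.log(total)


if __name__ == "__main__":
    scores = [0.3, -1.2, 2.0, 0.7, -0.4, 1.1]
    values = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0],
              [0.0, 0.3], [2.0, -1.0], [0.2, 0.2]]
    o1, l1 = attend(scores[:3], values[:3])    # "segment 0"
    o2, l2 = attend(scores[3:], values[3:])    # "segment 1"
    merged, lse = merge_lse(o1, l1, o2, l2)
    full, full_lse = attend(scores, values)
    print(max(abs(a - b) for a, b in zip(merged, full)), abs(lse - full_lse))
```

The math is cheap. The cost comes from doing it as separate eager launches after a device sync.
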

**Launch/sync count, V4-Pro CSA decode, single token, single layer:**
`top_k = 1024` → `n_segments = 1024/128 = 8`; `hd = 512` → `n_pv_tiles = 512/128 = 4`.
**32 kernel launches + 32 device syncs** per CSA layer, plus ~40 eager merge ops.
Across the interleaved CSA/HCA layers of a 61-layer model that is **~12k launches
and ~12k device syncs per decoded token.** At single-digit-µs launch+sync each,
that's milliseconds of pure overhead per token — a hard ceiling of tens of tok/s
before any real compute runs. At 1M context (KV-read bound, the entire point of
V4) this is the bottleneck, full stop.
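
The per-layer counts follow directly from the loop structure. The sketch below reproduces that arithmetic. It assumes a segment size and PV tile width of 128, as above, and about five eager merge ops per segment. The per-token total of ~12k launches is taken from the text rather than derived here. 5 µs is only an assumed point in the "single-digit-µs" range. Both function names are illustrative.

```python
def segmented_fmha_cost(top_k, hd, seg_size=128, pv_n_tile=128, merge_ops_per_seg=5):
    """Launch/sync/merge counts for one call of the segmented Python path."""
    n_segments = -(-top_k // seg_size)       # ceil: the tail segment gets padded
    n_pv_tiles = -(-hd // pv_n_tile)
    launches = n_segments * n_pv_tiles       # one launch per (segment, pv_tile)
    syncs = launches                         # one torch.cuda.synchronize() per launch
    merge_ops = n_segments * merge_ops_per_seg
    return {"n_segments": n_segments, "n_pv_tiles": n_pv_tiles,
            "launches": launches, "syncs": syncs, "merge_ops": merge_ops}


def overhead_ceiling(launches_per_token, us_per_launch_sync):
    """Pure launch+sync overhead per token (ms) and the tok/s ceiling it implies."""
    ms = launches_per_token * us_per_launch_sync / 1000.0
    return ms, 1000.0 / ms


if __name__ == "__main__":
    print(segmented_fmha_cost(top_k=1024, hd=512))  # 8 x 4 -> 32 launches, 32 syncs, 40 merges
    print(overhead_ceiling(12_000, 5.0))           # 60 ms overhead -> ~16.7 tok/s ceiling
```
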
Head-packing (MQA → all 128 Q heads in M) is already correct and keeps this from
also multiplying by heads — good. The segment/tile loop and the per-iteration sync
are the problem.

---
## The blocker
`dsv4/kernels/attention/fmha.py` (~line 620) writes O via
`utils.gemm.sm100.epilogue_tma_store`, which reads O straight from TMEM and
TMA-stores it. Per the file header (point 2): it **cannot accept flat_divide-based
GMEM coordinates**, so you can't fold (batch, head, kv_tile) into a multi-CTA grid.
That single epilogue choice is why multi-tile has to be done host-side, and why
D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it.
The TMEM round-trip ("D1.5 fundamentally broken") is a **CuTeDSL** atom-pairing
limitation, not hardware — CUTLASS C++ pairs Ld/St atoms that work. So the answer
is not "merge in Python," it's "use the one-way epilogue (no round-trip) or write
the raw CUDA kernel."

---
## The pattern to port FROM (already working on Blackwell)
`dsv4/kernels/gemm/dense.py` runs the correct one-way epilogue every MoE GEMM:
TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for
fusion. Reuse these symbols (already imported into `fmha.py`):
- `sm100_utils.compute_epilogue_tile_shape` (line ~313) — epi subtile
- `sm100_utils.make_smem_layout_epi` (line ~361) — SMEM staging layout
- `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition`
  (imported at fmha.py lines 71-72, currently unused on the O path)
- `epilogue_op: cutlass.Constexpr = lambda x: x` (line ~405) — the hook where FP4
pack lands later (NVFP4-1.2)

This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the
MoE kernel already proves out.

---
## Design of the fix
**1. Online softmax across KV tiles, in-kernel.** Replace the host-side log-sum-exp
merge with a running `(row_max, row_sum)` rescale inside the KV-tile loop — standard
FlashAttention-2. The kernel already emits un-normalized O + LSE + row_sums for a
single tile; lift the accumulation into the tile loop so one launch consumes all
selected KV. No host merge, no per-tile sync, no per-tile allocs.
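
A pure-Python model of that in-kernel loop for one query row is shown below. Each KV tile updates a running max `m` and running sum `l` and rescales the accumulator by `exp(m_old - m_new)`. The output is normalized once, after the last tile. This is the generic FlashAttention-2 recurrence, not code lifted from `fmha.py` or `fmha_sm100_tc.cuh`. The tile size and names are illustrative. Note that a short tail tile needs no padding.

```python
import math


def online_softmax_attention(scores, values, tile=4):
    """Single-pass attention over KV tiles with running (row_max, row_sum).

    scores: scaled QK^T logits for one query row.
    values: V rows, one per key.
    Returns (normalized output row, LSE) after consuming every tile once.
    """
    hd = len(values[0])
    m = -math.inf                 # running row max
    l = 0.0                       # running sum of exp(s - m)
    acc = [0.0] * hd              # un-normalized O (lives in TMEM in the kernel)
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)          # rescale old state; exp(-inf) == 0.0
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, v_tile))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc], m + math.log(l)


if __name__ == "__main__":
    scores = [0.1 * ((7 * i) % 11) - 0.5 for i in range(10)]
    values = [[float(i), -0.5 * i] for i in range(10)]
    print(online_softmax_attention(scores, values, tile=4))  # tiles of 4, 4, 2
```

One launch consumes all selected KV this way, and the result does not depend on the tile size. That independence is what makes the parity test against the Python-merge path meaningful.
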
**2. One-way epilogue.** After the tile loop, O lives in TMEM. Copy
TMEM→regs→SMEM→GMEM via the MoE epilogue path. No `epilogue_tma_store`, no TMEM
round-trip. Normalize by the final `row_sum` in registers before the SMEM store
(or keep emitting LSE if external normalization is still wanted — but it shouldn't
be needed once the merge is in-kernel).
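
To make the ordering in step 2 concrete, here is a host-side model of a one-way epilogue on plain Python lists. O is read from the accumulator one epilogue subtile at a time into "registers". It is normalized by the final `row_sum`, passed through an `epilogue_op` hook, staged in an "SMEM" buffer, and only then written to the global output at the CTA's tile origin. Nothing is written back into the accumulator. This is a sketch under stated assumptions: the subtile shape, the function name, and the `(row0, col0)` origin are hypothetical. In the kernel the subtile would come from `sm100_utils.compute_epilogue_tile_shape`.

```python
def one_way_epilogue(acc, row_sum, out, row0, col0, epi_m=2, epi_n=4,
                     epilogue_op=lambda x: x):
    """Copy one CTA's accumulator tile to `out` via regs and SMEM staging.

    acc:     M x N un-normalized O (stands in for TMEM); only ever read.
    row_sum: final softmax denominators, one per row.
    out:     global output (list of lists); this tile lands at (row0, col0).
    """
    M, N = len(acc), len(acc[0])
    for i0 in range(0, M, epi_m):
        for j0 in range(0, N, epi_n):
            rows = range(i0, min(i0 + epi_m, M))
            cols = range(j0, min(j0 + epi_n, N))
            # TMEM -> regs: load one epilogue subtile
            regs = [[acc[i][j] for j in cols] for i in rows]
            # in registers: normalize, then the fusion hook (e.g. a future FP4 pack)
            regs = [[epilogue_op(x / row_sum[i]) for x in r]
                    for i, r in zip(rows, regs)]
            # regs -> SMEM staging buffer
            smem = [r[:] for r in regs]
            # SMEM -> GMEM at global coordinates (a TMA store in the kernel)
            for di, i in enumerate(rows):
                for dj, j in enumerate(cols):
                    out[row0 + i][col0 + j] = smem[di][dj]
```

Because the store takes an explicit global origin, the same code serves any CTA in a multi-CTA grid. That is the property `epilogue_tma_store` lacks here.
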
**3. Multi-CTA grid.** With the one-way epilogue accepting flat_divide coords,
fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. `use_2cta_instrs`
applies for prefill/batched shapes (M ≥ 256); decode single-token stays 1-CTA.
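
One way to picture the fold in step 3: each CTA takes a flat index and recovers `(batch, kv_head_group, q_tile)` from it, with `q_tile` varying fastest. The sketch below is an assumption about how such a grid could be laid out, not the planned D2/P4 layout. `cta_coords`, `grid_size`, and `use_2cta_instrs` as a Python predicate are hypothetical. The M ≥ 256 threshold is the one stated above.

```python
def grid_size(batch, n_kv_groups, n_q_tiles):
    """Total CTAs when (batch, kv_head_group, q_tile) is folded into one grid axis."""
    return batch * n_kv_groups * n_q_tiles


def cta_coords(cta_id, n_kv_groups, n_q_tiles):
    """Map a flat CTA index to (batch, kv_head_group, q_tile); q_tile is fastest."""
    q_tile = cta_id % n_q_tiles
    kv_group = (cta_id // n_q_tiles) % n_kv_groups
    batch = cta_id // (n_q_tiles * n_kv_groups)
    return batch, kv_group, q_tile


def use_2cta_instrs(m_rows):
    """2-CTA MMA only for prefill/batched shapes; single-token decode stays 1-CTA."""
    return m_rows >= 256


if __name__ == "__main__":
    B, G, T = 2, 3, 4
    for cta in range(grid_size(B, G, T)):
        print(cta, cta_coords(cta, G, T))
```
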
**4. Preallocated scratch.** Whatever staging remains is allocated once at warmup
and reused (cudagraph-safe), never `torch.zeros` on the hot path.
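
A minimal sketch of a warmup-time workspace follows. It takes a caller-supplied allocator so it runs without a GPU. In production that callable would be something like `lambda shape: torch.empty(shape, device="cuda")`. The class name and buffer names are hypothetical. On the hot path the class only looks buffers up, and it raises instead of allocating, so churn shows up as an error rather than as latency.

```python
class DecodeScratch:
    """Buffers allocated once at warmup and reused on every decode step."""

    def __init__(self, alloc):
        self._alloc = alloc          # callable: shape tuple -> buffer
        self._bufs = {}
        self._frozen = False

    def reserve(self, name, shape):
        """Warmup only: allocate a named buffer of a fixed shape."""
        if self._frozen:
            raise RuntimeError(f"allocation of {name!r} after warmup")
        self._bufs[name] = (tuple(shape), self._alloc(tuple(shape)))

    def freeze(self):
        """Call once warmup (and any CUDA graph capture) is finished."""
        self._frozen = True

    def get(self, name, shape):
        """Hot path: return the preallocated buffer; never allocates."""
        saved_shape, buf = self._bufs[name]
        if saved_shape != tuple(shape):
            raise RuntimeError(f"{name!r}: expected {saved_shape}, got {tuple(shape)}")
        return buf
```

Returning the same buffer object every step also keeps addresses stable, which is what CUDA graph replay needs.
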
### Two ramps — pick per doctrine, not per convenience
- **Ramp A (preferred if CuTeDSL cooperates):** finish the one-way epilogue in
`fmha_smem_acc.py` against the MoE pattern. The status of that file is unclear ("many
commits, unclear if working"), so the first action is to determine whether it works
with a print-and-diff, not by reading the commit log.
- **Ramp B (when CuTeDSL walls — the legitimate fallback):** `fmha_sm100_tc.cuh`
is already the right shape — `tcgen05.mma` SS for QK, TS for PV, TMEM accumulators,
UMMA descriptors, and it already carries `sRowMax/sRowSum` running state for
in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++,
it is Blackwell-native, and it has no `epilogue_tma_store` constraint because you
control the store.
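
For the first action under Ramp A, a print-and-diff can be as small as the helper below. It prints max absolute error, max relative error, and cosine similarity between a candidate output and the reference, then returns the cosine so it can be gated. It works on flat sequences of floats, so flatten tensors first (for a PyTorch tensor, `.flatten().tolist()`). The helper name is hypothetical.

```python
import math


def print_and_diff(name, out, ref):
    """Print error stats between a kernel output and its reference; return cosine."""
    if len(out) != len(ref):
        raise ValueError(f"{name}: length {len(out)} != reference {len(ref)}")
    if not all(math.isfinite(x) for x in out):
        print(f"{name}: non-finite values in output")
        return float("nan")
    diffs = [abs(a - b) for a, b in zip(out, ref)]
    max_abs = max(diffs)
    max_rel = max(d / max(abs(b), 1e-12) for d, b in zip(diffs, ref))
    dot = sum(a * b for a, b in zip(out, ref))
    norm = math.sqrt(sum(a * a for a in out)) * math.sqrt(sum(b * b for b in ref))
    cos = dot / norm if norm > 0 else float("nan")
    print(f"{name}: max_abs={max_abs:.3e} max_rel={max_rel:.3e} cos={cos:.7f}")
    return cos
```

Printing all three numbers matters. Cosine alone can hide a per-row scale error, which is exactly what a missing `row_sum` normalization looks like.
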

---
## Test plan — measure, don't eyeball
1. **Correctness parity:** new in-kernel path vs current Python-merge path on the
same inputs across the matrix that already passes — hd ∈ {64,128,256,512},
n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998
(must match the path it replaces, this is a refactor not a numerics change).
2. **Launch + sync count:** Nsight Systems (or a CUPTI launch counter) on one
decoded token, V4-Pro CSA layer. Record launches and `cudaDeviceSynchronize`
calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from
32 → 0 on the hot path.
3. **Latency:** per-token decode latency at context 8k / 128k / 1M, before/after.
This is the number that answers Priority 1's "does the overhead matter"
question. (Spoiler from the launch math: it matters.)
4. **Unblock check:** confirm the new epilogue accepts a multi-CTA grid (D2 smoke
test, M≥256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2
stub: `epilogue_op` that amax+packs, behind a flag, off by default).
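
The gates in items 1 and 2 can be encoded so they cannot drift. The sketch below assumes the four masking features toggle independently. The plan lists them but does not say whether every combination is valid, so prune as needed. All function names are hypothetical, and the launch/sync counts would come from the Nsight Systems or CUPTI trace.

```python
from itertools import product

COS_GATE = 0.999998
HEAD_DIMS = (64, 128, 256, 512)
N_SEGMENTS = (1, 2, 8)
FEATURES = ("swa_mask", "causal", "sink", "n_comp")


def parity_cases():
    """Yield every (hd, n_segments, feature flags) case in the parity matrix."""
    for hd, n_seg in product(HEAD_DIMS, N_SEGMENTS):
        for flags in product((False, True), repeat=len(FEATURES)):
            yield hd, n_seg, dict(zip(FEATURES, flags))


def check_parity(cos):
    """Item 1: the new path must match the path it replaces."""
    if not cos >= COS_GATE:
        raise AssertionError(f"cos {cos:.7f} below gate {COS_GATE}")


def check_counts(before, after, n_segments, n_pv_tiles):
    """Item 2: per decode layer, launches n_seg*n_pv -> 1 and syncs -> 0."""
    expected = n_segments * n_pv_tiles
    problems = []
    if before["launches"] != expected or before["syncs"] != expected:
        problems.append(f"baseline {before} != {expected} launches/syncs")
    if after["launches"] != 1:
        problems.append(f"new path launches {after['launches']} != 1")
    if after["syncs"] != 0:
        problems.append(f"new path syncs {after['syncs']} != 0")
    if problems:
        raise AssertionError("; ".join(problems))
```
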

---
## DOCTRINE REMINDER (this issue is the cautionary tale)
1. **CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python.** The Python KV merge is what
rule 1 exists to prevent. `fmha_sm100_tc.cuh` is the correct fallback. A host-side
loop with a per-iteration `cudaDeviceSynchronize` is never the production answer.
2. **Raw CUDA ≠ scalar math.** The fallback kernel stays `tcgen05`/UMMA/TMEM/TMA with
warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar
FMA to get it compiling.
3. **Print, don't guess.** Before touching `fmha_smem_acc.py`, print its actual
output and diff against the reference — do not infer its status from commit
messages. When wiring the epilogue, print epi tile shape, SMEM layout, TMEM
offsets, and the MMA instruction shape at construction, and code to those values.
The epilogue is exactly the kind of layout surface where a guess passes a toy
test and corrupts at real shapes.
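
Doctrine point 2 calls for warp-level softmax reductions. The usual form is an XOR butterfly over the 32 lanes, `__shfl_xor_sync` with offsets 16, 8, 4, 2, 1, after which every lane holds the full row max or row sum. Below is a host-side Python model of that exchange pattern, useful for checking lane arithmetic before writing the CUDA. It simulates lanes as a list and is not kernel code. The one-score-per-lane layout is an assumption for illustration.

```python
import math

WARP = 32


def butterfly_allreduce(lane_vals, op):
    """Model of a warp all-reduce built from XOR shuffles (like __shfl_xor_sync)."""
    assert len(lane_vals) == WARP
    vals = list(lane_vals)
    offset = WARP // 2
    while offset >= 1:
        # every lane combines with its partner's value from the previous step
        vals = [op(vals[lane], vals[lane ^ offset]) for lane in range(WARP)]
        offset //= 2
    return vals


def warp_softmax_stats(row_scores):
    """Row max and row sum of exp(s - max), with one score held per lane."""
    row_max = butterfly_allreduce(row_scores, max)[0]
    exps = [math.exp(s - row_max) for s in row_scores]
    row_sum = butterfly_allreduce(exps, lambda a, b: a + b)[0]
    return row_max, row_sum
```

After log2(32) = 5 steps, every lane holds the same result, so no broadcast is needed before the rescale.
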