- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size), one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
# ISSUE (Priority 2): kill the Python KV merge; one-way in-kernel epilogue + online softmax

**Status:** OPEN. Structural shortcut on the FMHA decode hot path.

**Severity:** HIGH for "production grade." Correctness is fine (cos 0.999998); the
problem is this path can never be fast, and it blocks the entire multi-CTA / FP4
fusion chain (D2/P4, NVFP4-1.2, NVFP4-2).

**Scope:** `dsv4/kernels/attention/production.py` (`_run_fmha_segmented`),
`dsv4/kernels/attention/fmha.py` (epilogue), the WIP `fmha_smem_acc.py`, and the
raw-CUDA `fmha_sm100_tc.cuh`.

---

## TL;DR

Multi-KV-tile attention is currently done by launching the single-tile CuTeDSL
kernel once per (segment × pv_tile), writing un-normalized O + LSE into tensors
that the host loop allocates per tile, and merging tiles with **eager PyTorch ops**,
with a full **`torch.cuda.synchronize()` inside the inner loop**. That is
orchestration, not a kernel. The fix is the standard FlashAttention shape: **loop KV
tiles inside the kernel with running max/sum (online softmax), and write the result
once through a one-way TMEM→regs→SMEM→GMEM epilogue**. This is the exact pattern the
MoE GEMM already runs correctly on Blackwell.

This is the load-bearing example of doctrine rule 1: hitting a CuTeDSL wall is not
a license to fall back to Python.

---

## The shortcut, measured (not estimated)

`dsv4/kernels/attention/production.py`, `_run_fmha_segmented`:

```python
for seg in range(n_segments):
    ...
    k_seg = torch.cat([k_seg, torch.zeros(...)])             # alloc + copy to pad
    seg_o = torch.zeros(M, hd, ...)                          # per-seg allocs
    for nt in range(n_pv_tiles):
        c_tile = torch.zeros(M, pv_n_tile, 1, ...)           # per-tile allocs
        lse_tensor = torch.zeros(M, 1, 1, ...)
        mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(...)   # descriptor rebuild
        ...                                                  #   every iteration
        compiled(mQ, mK, mV, mC, stream, lse=mLSE, ...)      # 1 launch
        torch.cuda.synchronize()                             # <-- FULL DEVICE SYNC
        seg_o[:, v_start:v_end] = c_tile[...].float()
    # eager exp/log/div merge (≈5 launches)
    o_accum = (e_old*o_accum + e_new*seg_o_norm) / e_sum
```
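
For reference, the last line of the excerpt is a two-way log-sum-exp merge of normalized partial outputs. The sketch below reproduces that arithmetic in pure Python for one query row with scalar values. It assumes `e_old` and `e_new` are `exp(lse - m)` against a shared max `m` and that `e_sum` is their sum; the excerpt elides those definitions, so this is a guess at the shape of the math. The helper names `attend` and `merge` are hypothetical and do not come from `production.py`.

```python
import math

def attend(scores, values):
    """Softmax-weighted average over one segment; returns (normalized o, lse)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    denom = sum(w)
    o = sum(wi * vi for wi, vi in zip(w, values)) / denom
    return o, m + math.log(denom)

def merge(o_old, lse_old, o_new, lse_new):
    """Combine two normalized partial outputs using their LSEs (assumed form)."""
    m = max(lse_old, lse_new)
    e_old = math.exp(lse_old - m)
    e_new = math.exp(lse_new - m)
    e_sum = e_old + e_new
    o = (e_old * o_old + e_new * o_new) / e_sum
    return o, m + math.log(e_sum)

scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7]
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0]

# One-shot result over all six keys versus two three-key segments merged afterwards.
o_full, lse_full = attend(scores, values)
o_a, lse_a = attend(scores[:3], values[:3])
o_b, lse_b = attend(scores[3:], values[3:])
o_merged, lse_merged = merge(o_a, lse_a, o_b, lse_b)

assert abs(o_merged - o_full) < 1e-9
assert abs(lse_merged - lse_full) < 1e-9
```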

Per-call cost on the hot path, every iteration:

- 1 kernel launch **+ 1 full `cudaDeviceSynchronize`** (serializes the stream:
  no overlap, no pipelining, no async)
- 6+ `torch.zeros` allocations (allocator churn)
- `torch.cat` to pad the partial tail segment (alloc + copy)
- `ct.from_dlpack(...).mark_layout_dynamic(...)` ×6 (host-side descriptor rebuild)
- ~5 eager merge ops per segment

**Launch/sync count, V4-Pro CSA decode, single token, single layer:**
`top_k = 1024` → `n_segments = 1024/128 = 8`; `hd = 512` → `n_pv_tiles = 512/128 = 4`.
That is 8 × 4 = **32 kernel launches + 32 device syncs** per CSA layer, plus
8 × ~5 = ~40 eager merge ops.
Across the interleaved CSA/HCA layers of a 61-layer model that is **~1–2k launches
and ~1–2k device syncs per decoded token.** At single-digit-µs launch+sync each,
that's milliseconds of pure overhead per token, a hard ceiling of tens of tok/s
before any real compute runs. At 1M context (KV-read bound, the entire point of
V4) this is the bottleneck, full stop.
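
The launch math above can be reproduced with a few lines of arithmetic. This is a sketch under stated assumptions: the segment length and PV tile width of 128 are read off the divisions quoted above, `ASSUMED_US_PER_ITER` is an illustrative placeholder for "single-digit µs" rather than a measurement, and the 61-layer figure is an upper bound that charges every layer the CSA cost. The function and parameter names are hypothetical.

```python
def launch_counts(top_k, seg_len, hd, pv_tile, merge_ops_per_seg=5):
    """Per-layer host-loop counts for the segmented path (ceil division for tails)."""
    n_segments = -(-top_k // seg_len)
    n_pv_tiles = -(-hd // pv_tile)
    launches = n_segments * n_pv_tiles      # one kernel launch per (segment, pv_tile)
    syncs = launches                        # one device sync per launch
    merge_ops = n_segments * merge_ops_per_seg
    return n_segments, n_pv_tiles, launches, syncs, merge_ops

n_segments, n_pv_tiles, launches, syncs, merge_ops = launch_counts(
    top_k=1024, seg_len=128, hd=512, pv_tile=128
)
print(n_segments, n_pv_tiles, launches, syncs, merge_ops)   # 8 4 32 32 40

N_LAYERS = 61                 # from the text; upper bound if every layer paid the CSA cost
ASSUMED_US_PER_ITER = 5.0     # assumption: illustrative launch+sync cost, not measured

upper_launches = N_LAYERS * launches                       # 1952
overhead_ms = upper_launches * ASSUMED_US_PER_ITER / 1000.0
print(upper_launches, round(overhead_ms, 2))                # 1952 9.76
```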

Head-packing (MQA → all 128 Q heads in M) is already correct and keeps this from
also multiplying by heads, which is good. The segment/tile loop and the
per-iteration sync are the problem.

---

## The blocker

`dsv4/kernels/attention/fmha.py` (~line 620) writes O via
`utils.gemm.sm100.epilogue_tma_store`, which reads O straight from TMEM and
TMA-stores it. Per the file header (point 2), it **cannot accept `flat_divide`-based
GMEM coordinates**, so you can't fold (batch, head, kv_tile) into a multi-CTA grid.
That single epilogue choice is why multi-tile has to be done host-side, and why
D2/P4 (multi-CTA) and FP4 output fusion are all stuck behind it.

The TMEM round-trip ("D1.5 fundamentally broken") is a **CuTeDSL** atom-pairing
limitation, not hardware: CUTLASS C++ pairs Ld/St atoms that work. So the answer
is not "merge in Python," it's "use the one-way epilogue (no round-trip) or write
the raw CUDA kernel."

---

## The pattern to port FROM (already working on Blackwell)

`dsv4/kernels/gemm/dense.py` runs the correct one-way epilogue on every MoE GEMM:
TMEM → regs → SMEM → GMEM, no round-trip, multi-CTA-safe, with a register slot for
fusion. Reuse these symbols (already imported into `fmha.py`):

- `sm100_utils.compute_epilogue_tile_shape` (line ~313): epi subtile
- `sm100_utils.make_smem_layout_epi` (line ~361): SMEM staging layout
- `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition`
  (imported at `fmha.py` lines 71–72, currently unused on the O path)
- `epilogue_op: cutlass.Constexpr = lambda x: x` (line ~405): the hook where FP4
  pack lands later (NVFP4-1.2)

This is not new infrastructure. It's wiring the FMHA O-store to the epilogue the
MoE kernel already proves out.

---

## Design of the fix

**1. Online softmax across KV tiles, in-kernel.** Replace the host-side log-sum-exp
merge with a running `(row_max, row_sum)` rescale inside the KV-tile loop, as in
standard FlashAttention-2. The kernel already emits un-normalized O + LSE + row_sums
for a single tile; lift the accumulation into the tile loop so one launch consumes
all selected KV. No host merge, no per-tile sync, no per-tile allocs.

**2. One-way epilogue.** After the tile loop, O lives in TMEM. Copy
TMEM→regs→SMEM→GMEM via the MoE epilogue path. No `epilogue_tma_store`, no TMEM
round-trip. Normalize by the final `row_sum` in registers before the SMEM store
(or keep emitting LSE if external normalization is still wanted, though it shouldn't
be needed once the merge is in-kernel).

**3. Multi-CTA grid.** With the one-way epilogue accepting `flat_divide` coords,
fold (batch, kv_head_group, [q_tile]) into the grid. This is D2/P4. `use_2cta_instrs`
applies for prefill/batched shapes (M ≥ 256); decode single-token stays 1-CTA.

**4. Preallocated scratch.** Whatever staging remains is allocated once at warmup
and reused (cudagraph-safe), never `torch.zeros` on the hot path.
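
Items 1 and 2 together amount to a single pass that carries `(row_max, row_sum)` plus an un-normalized accumulator across tiles and divides once at the end. The pure-Python sketch below shows that control flow for one query row with scalar values. It is a host-side illustration of the arithmetic only, not the kernel: it says nothing about TMEM, warps, or the epilogue, and the function names and the `tile` parameter are hypothetical.

```python
import math

def softmax_attend(scores, values):
    """Reference: one-shot softmax over all keys."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

def online_attend(scores, values, tile):
    """Single pass over KV tiles with running max/sum; normalizes once at the end."""
    row_max = float("-inf")
    row_sum = 0.0
    acc = 0.0                                   # un-normalized O
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(row_max, max(s_tile))
        scale = math.exp(row_max - new_max)     # 0.0 on the first tile (exp(-inf))
        p = [math.exp(s - new_max) for s in s_tile]
        row_sum = row_sum * scale + sum(p)      # rescale old sum, add this tile
        acc = acc * scale + sum(pi * vi for pi, vi in zip(p, v_tile))
        row_max = new_max
    return acc / row_sum                        # final normalization by row_sum

scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7, 4.0, -3.0, 0.9, 1.7]
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 0.25, 1.5, -0.5, 2.5]

reference = softmax_attend(scores, values)
tiled = online_attend(scores, values, tile=4)   # 4 + 4 + 2: includes a partial tail
assert abs(tiled - reference) < 1e-9
```

Because the rescale factor is applied to both the accumulator and the running sum, the result does not depend on the tile size, which is why the host-side merge and its per-tile sync become unnecessary once the loop lives in the kernel.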

### Two ramps: pick per doctrine, not per convenience

- **Ramp A (preferred if CuTeDSL cooperates):** finish the one-way epilogue in
  `fmha_smem_acc.py` against the MoE pattern. Status of that file is unclear ("many
  commits, unclear if working"), so the first action is to determine whether it
  works, with a print-and-diff, not by reading the diff log.
- **Ramp B (when CuTeDSL walls; the legitimate fallback):** `fmha_sm100_tc.cuh`
  is already the right shape: `tcgen05.mma` SS for QK, TS for PV, TMEM accumulators,
  UMMA descriptors, and it already carries `sRowMax/sRowSum` running state for
  in-kernel online softmax. Finish THIS, not a Python merge. It is raw CUDA C++,
  it is Blackwell-native, and it has no `epilogue_tma_store` constraint because you
  control the store.

---

## Test plan: measure, don't eyeball

1. **Correctness parity:** new in-kernel path vs current Python-merge path on the
   same inputs across the matrix that already passes: hd ∈ {64,128,256,512},
   n_segments ∈ {1,2,8}, with SWA mask / causal / sink / n_comp. Gate: cos ≥ 0.999998
   (must match the path it replaces; this is a refactor, not a numerics change).
2. **Launch + sync count:** Nsight Systems (or a CUPTI launch counter) on one
   decoded token, V4-Pro CSA layer. Record launches and `cudaDeviceSynchronize`
   calls before/after. Target: per-layer launches from 32 → 1 (decode), syncs from
   32 → 0 on the hot path.
3. **Latency:** per-token decode latency at context 8k / 128k / 1M, before/after.
   This is the number that says whether Priority 1's "does the overhead matter"
   question is answered. (Spoiler from the launch math: it matters.)
4. **Unblock check:** confirm the new epilogue accepts a multi-CTA grid (D2 smoke
   test, M≥256 prefill) and exposes a register slot for the FP4 pack (NVFP4-1.2
   stub: `epilogue_op` that amax+packs, behind a flag, off by default).
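
The parity gate in step 1 is a cosine-similarity threshold. A minimal sketch of such a check on flattened outputs follows, in pure Python so it runs anywhere. The threshold is the one quoted above; the helper names and the sample vectors are hypothetical, the real test presumably compares GPU tensors, and the helper assumes neither vector is all zeros.

```python
import math

COS_GATE = 0.999998  # threshold quoted in the test plan

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def passes_gate(new_out, ref_out, gate=COS_GATE):
    """True when the new path matches the reference path to within the gate."""
    return cosine_similarity(new_out, ref_out) >= gate

ref = [0.5, -1.25, 2.0, 0.75]
close = [x * (1.0 + 1e-4 * (-1) ** i) for i, x in enumerate(ref)]  # tiny relative noise
far = [0.5, -1.25, 2.0, -0.75]                                     # one sign flipped

assert passes_gate(close, ref)
assert not passes_gate(far, ref)
```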

---

## DOCTRINE REMINDER (this issue is the cautionary tale)

1. **CuTeDSL/CUTLASS wall → raw CUDA C++, NOT Python.** The Python KV merge is what
   rule 1 exists to prevent. `fmha_sm100_tc.cuh` is the correct fallback. A host-side
   loop with a per-iteration `cudaDeviceSynchronize` is never the production answer.
2. **Raw CUDA ≠ scalar math.** The fallback kernel stays `tcgen05`/UMMA/TMEM/TMA with
   warp-level softmax reductions. Do not let an agent "simplify" the PV into scalar
   FMA to get it compiling.
3. **Print, don't guess.** Before touching `fmha_smem_acc.py`, print its actual
   output and diff against the reference; do not infer its status from commit
   messages. When wiring the epilogue, print epi tile shape, SMEM layout, TMEM
   offsets, and the MMA instruction shape at construction, and code to those values.
   The epilogue is exactly the kind of layout surface where a guess passes a toy
   test and corrupts at real shapes.