- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
(c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
(dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)
Architecture:
Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
MQA: k_head_stride=0 so all Q heads share same K/V
Single kernel launch, zero cudaDeviceSynchronize on hot path
Normalized output for single-segment decode
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.
Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.
Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:
1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
REGS→TMEM (paired store via retile_to_S). No layout mismatch.
2. Final O output: One-way TMEM→REGS→SMEM→GMEM using
epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
+ TMA partition. Register-level normalization (divide by row_sum)
or raw BF16 cast for D5a path.
This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)
Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.