The 6-warp multi-tile kernel already supports batch natively via
dim3 grid(1, n_h, batch). Removed Python for-loop for 4D input.
Single kernel launch per layer for batched decode instead of
batch_size launches.
T>1 prefill still uses per-batch dispatch (E8 future work).
- indexer/__init__.py: compute_index_scores_topk now calls
run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail
and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
Both gather_kv.cu and gather_swa.cu are compiled into one .so.
Only gather_kv.cu defines the PYBIND11_MODULE; gather_swa.cu
just provides the function implementations.
cp.async.bulk.tensor store (SMEM→GMEM) is NOT available on SM100.
The CUTLASS SM100 epilogue uses st.global directly.
The one-way epilogue pipeline is now:
1. TMEM → regs (tcgen05.ld, warp-collective)
2. epilogue_op in regs (normalize, FP4 hook via ENABLE_FP4_EPILOGUE)
3. regs → SMEM (row-major, sO_epi)
4. SMEM → GMEM (direct write)
This is the same pattern as the MoE kernel but with st.global instead
of TMA store. Multi-CTA (D2) will use st.global with flat_divide coords.
Removed: tma_o from FmhaParams, fmha_multihead_decode_tma_launch,
sMbarStore from SMEM, broken TMA store PTX from fmha_tma.cuh.
TMA store: cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes
Uses mbarrier for completion, not bulk_group. Restored sMbarStore to SMEM.
TMA store uses cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group
NOT mbarrier::complete_tx::bytes. Completion tracked via:
- cp.async.bulk.commit_group (after issuing stores)
- cp.async.bulk.wait_group.read 0 (wait for all groups)
Removed sMbarStore from SMEM allocations (no longer needed).
- fmha_multitile_capi.cu: C API wrapper for TMA multi-tile kernel
Creates TMA descriptors per (head, batch), launches kernel
- fmha_multitile_op.py: nvcc precompile + ctypes loader
- production.py: dispatch to multitile for N>128 or hd=512
- Reverted fmha_6warp_multihead.cuh to working single-tile version
- The TMA multi-tile kernel already passes 72 configs (D1.5)
HD=64/128/256/512 × T=1/4/32/128 × s_k=128/256/384/512
Single-tile path is IDENTICAL to the working pre-P5 kernel.
Multi-tile path uses FA2 online softmax with sOacc accumulator.
Runtime branch on is_multi_tile = (n_kv_tiles > 1).
- Kernel loops over KV tiles internally with running max/sum rescale
- SMEM accumulator sOacc[hd] replaces TMEM accumulation across tiles
- P is UN-NORMALIZED for multi-tile (exp(s-max), not /sum)
- Per KV tile: QK→softmax→PV→TMEM→read→add to sOacc
- Final: O = sOacc / running_sum
- Single tile (n_kv_tiles=1): same as before, no rescale
- Updated CAPI, Python loader, production.py fast path
- Added multi-tile test cases (N=256, 512)
P was already normalized in softmax step. PV = P_norm @ V gives the
correct attention output. Dividing by row_sum again in the epilogue
produces O = O_correct / row_sum (128x too small for uniform data).
- fmha_multihead_capi.cu: SMEM formula matches standalone test
Added cudaFuncSetAttribute for dynamic SMEM > 48KB
- fmha_multihead_op.py: pad K/V to N=128 when N<128
(kernel softmax loop is hardcoded to SK_TILE=128)
- Removed fmha_multihead_launch.cu (ATen approach, didn't work)
- Removed test_p3_ctypes_minimal.py (superseded by main test)
- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
(c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
(dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)
Architecture:
Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
MQA: k_head_stride=0 so all Q heads share same K/V
Single kernel launch, zero cudaDeviceSynchronize on hot path
Normalized output for single-segment decode