- test_p3_fast_decode.py: clean kernel test + full API test
- Removed debug tests (sanity, v_debug, v_ref_debug)
- Double normalization fix verified: kernel output matches reference
at cos >= 0.999990 across all MHA/MQA/GQA configs
P was already normalized in softmax step. PV = P_norm @ V gives the
correct attention output. Dividing by row_sum again in the epilogue
produces O = O_correct / row_sum (128x too small for uniform data).
- fmha_multihead_capi.cu: SMEM formula matches standalone test
Added cudaFuncSetAttribute for dynamic SMEM > 48KB
- fmha_multihead_op.py: pad K/V to N=128 when N<128
(kernel softmax loop is hardcoded to SK_TILE=128)
- Removed fmha_multihead_launch.cu (ATen approach, didn't work)
- Removed test_p3_ctypes_minimal.py (superseded by main test)
- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
(c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
(dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)
Architecture:
Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
MQA: k_head_stride=0 so all Q heads share same K/V
Single kernel launch, zero cudaDeviceSynchronize on hot path
Normalized output for single-segment decode