Key fixes for fmha_epilogue_sm100.cuh hang:
- tcgen05.ld/st are WARP-COLLECTIVE: ALL 32 lanes must execute
- Old code guarded TMEM ops with if(tid==0) = warp divergence = HANG
- tmem_dealloc now uses tmem_base (value from alloc), not SMEM pointer
- Compute attention in SMEM, then do one-way TMEM pipeline:
SMEM → TMEM (warp-collective store) → regs (warp-collective load)
→ normalize in regs → BF16 cast → GMEM
- This proves the MoE-style one-way correction epilogue on FMHA
Also: enable TMEM kernel test + hd=128 in standalone test
ROOT CAUSE of TMET hang: tcgen05.fence.cta_group::1.sync.aligned is
NOT a valid PTX instruction. The correct TMEM ordering primitives are:
- tcgen05.wait::st.sync.aligned (wait for TMEM stores to complete)
- tcgen05.wait::ld.sync.aligned (wait for TMEM loads to complete)
Found in cutlass/arch/barrier.h fence_view_async_tmem_store/load.
What changed:
- Moved fmha_backup_pre_epilog.py, fmha_backup_v2.py, fmha_smem_acc.py to archive/
- Deleted fmha.py.backup (git has history)
- Added detailed heredoc headers to ALL files documenting:
* WHAT WORKS and WHAT'S BROKEN
* WHY each limitation exists (CuTeDSL toolchain gaps)
* KEY INSIGHTS FOR NVIDIA (what CuTeDSL is missing)
* What each file unblocks if fixed
File status:
fmha.py — CuTeDSL FMHA, cos 0.999998, D1.5 workaround
fmha_common.cuh — Raw CUDA shared defs (BF16, TMEM ops)
fmha_sm100.cuh — Raw CUDA reference, cos 0.999999
fmha_epilogue_sm100.cuh — Raw CUDA TMEM epilogue, HANGS (needs debug)
fmha_sm100_launch.cu — PyTorch binding (JIT broken, nvcc works)
production.py — CuTeDSL production wrapper (partial)
archive/ — Historical backups with explanation headers
New file: fmha_epilogue_sm100.cuh
- TMEM alloc/dealloc/load/store via tcgen05 PTX
- One-way correction epilogue: TMEM→regs→normalize→BF16→GMEM
- D1.5 fix: O rescale in REGISTERS (TMEM→regs→multiply→TMEM)
- Same pattern as MoE epilogue but with normalize instead of SwiGLU
- Unblocks D2 multi-CTA and NVFP4-1.2 (register slot for FP4 pack)
Test: hd=64 + hd=128, reference vs TMEM kernels
Use thread 0 for all computation (slow but correct).
SMEM for Q and O sharing across threads.
Online softmax with O rescale — correct D1.5 approach.
D3 SWA mask implemented.
Target: cos ~0.999998 then parallelize.
Simpler approach first: scalar Q@K^T, softmax, P@V in registers.
No TMEM/MMA yet — verify correctness first, then replace with tcgen05.
- 192-thread CTA, all threads cooperate on one (batch, head)
- Online softmax with O rescale (correct D1.5 approach)
- D3 SWA mask, D4 causal (TODO), D5c sink (TODO)
- KV loaded in blocks of 128 for SMEM efficiency
- Correctness target: cos ~0.999998 against PyTorch reference
- tcgen05.mma.cta_group::1.kind::f16 [tmem_c], desc_a, desc_b, idescE_hi, scaleC, {mask0..3}, pred
- idescE is upper 32 bits of the E descriptor
- scaleC is a float (1.0 for accumulate)
- mask is 4 uint32 values (0xFFFFFFFF for no masking)
CUTLASS headers transitively include cuda_bf16.h which has a CUDA 13.2
in_place_from bug. Writing tcgen05 PTX directly via inline asm instead.
No dependencies on CUTLASS C++ — pure PTX + CUDA runtime.