Canonical UMMA layout for SWIZZLE_NONE:
- MN-major (128, 64): LBO=16, SBO=128 (from logical_divide Tile(1,8))
- K-major (128, 64): LBO=16, SBO=32 (from logical_divide Tile(8,2))
Using simple row-major SMEM layout (no swizzle, no permutation).
Data is written directly to SMEM in row-major order.
The descriptor strides describe the canonical layout.
Proper implementation of the SMEM layout that tcgen05.mma expects:
- SWIZZLE_128B (layout_type=2) for both MN-major A and K-major B
- Swizzle<3,4,3> applied to element offsets before SMEM write
- MN_SW128 atom: (1024, 8) BF16, stride (1, 1024)
- K_SW128 atom: (8, 1024) BF16, stride (1, 8)
- umma_smem_write/read functions for both MN and K major
- Descriptor with correct leading_byte_offset and stride_byte_offset
This is the RIGHT WAY. No shortcuts.
The descriptor bitfield is completely different from what I assumed:
- [0,14) start_address (smem_ptr >> 4)
- [16,30) leading_byte_offset (row stride bytes >> 4)
- [32,46) stride_byte_offset
- [46,48) version = 1 (Blackwell)
- [61,64) layout_type (0=NONE, 2=128B, 4=64B, 6=32B)
- idescE = desc >> 32, passed as separate arg to MMA PTX
The 64-bit descriptor uses byte offsets (not log2 or element counts).
The upper 32 bits are reinterpreted by the MMA hardware as idescE.
Step 1 of tensor-core acceleration:
- fmha_umma_desc.cuh: UMMA SMEM descriptor construction (raw bitfield)
- fmha_qk_verify.cuh: QK GEMM using tcgen05.mma SS (SMEM A, SMEM B → TMEM C)
- test_qk_mma.cu: standalone test comparing MMA output vs CPU reference
Key design decisions:
- UMMA descriptors built from raw bitfield (no CuTe dependency)
- tcgen05.mma called by one lane per warp (elect_one_sync pattern)
- Q: (128, HD) MN-major, K: (128, HD) K-major (transposed via descriptor)
- S: (128, 128) in TMEM, row 0 read back via tcgen05.ld
Updated fmha_common.cuh, fmha_sm100.cuh, fmha_epilogue_sm100.cuh,
and fmha_sm100_launch.cuh with comprehensive here-docs explaining:
1. The 4 CuTeDSL gaps that forced us to raw CUDA C++:
- TMEM round-trip broken (Ld32x32bOp/St32x32bOp column mismatch)
- Float→int impossible (arith.fptosi not lowerable)
- epilogue_tma_store blocks multi-CTA
- hd=512 MLIR optimizer hangs
2. TMEM lane mapping (verified on B200):
- Lane i → positions i*4+0..3, 128 FP32 per column
- Warp-collective: ALL 32 lanes must call ld/st or HANG
- Column address = tmem_base + column_index
3. Key insight for NVIDIA: float→int gap is the single most
impactful limitation, blocking ALL quantization-epilogue fusion
Removed all dead code from the first (broken) attention loop approach.
Clean pipeline: SMEM attention → TMEM write → TMEM read → normalize → GMEM.
Also renamed sPvBuf to sO for clarity (same as reference kernel).
CRITICAL FIX: tcgen05.st 16x256b.x1.b32 is warp-collective where:
- Lane i writes to positions i*4+0..i*4+3 within the column
- 32 lanes × 4 FP32 = 128 FP32 per column
- For row 0: lane 0 = positions 0-3, lane 1 = 4-7, ..., lane 31 = 124-127
Old code iterated col = lane; col < N; col += 32, treating each lane
as owning a separate column. That was WRONG — all 32 lanes share each
column, each owning 4 positions within it.
New code: HD values need ceil(HD/128) columns. Lane i writes
sPvBuf[i*4+0..3] to column 0 (or column 1 for HD > 128).
Verified via test_tmem_lane_mapping.cu on B200.
Root cause of TMEM epilogue hang: tmem_store/tmem_load are
warp-collective operations requiring ALL 32 lanes to participate.
The loop 'for (col = lane; col < TMEM_O_COLS; col += WARP)' with
TMEM_O_COLS=16 and WARP=32 means only lanes 0-15 execute the op.
Lanes 16-31 skip it = warp divergence on collective = HANG.
Fix: loop over TMEM_N (>= 32, power of 2) so all 32 lanes
participate. Columns beyond TMEM_O_COLS write don't-care data
to allocated-but-unused TMEM columns.
Two new turnkey harness scripts for .cu tests:
- fire_b200_cuda_test: compile+run+poll, kills everything first,
deletes old logs, one test at a time, screen-based, timeout
- check_b200_cuda: peek at running test log, or kill hung test
README updated with CUDA harness documentation.
Removed janky tests/run_cuda_test.sh.
Key fixes for fmha_epilogue_sm100.cuh hang:
- tcgen05.ld/st are WARP-COLLECTIVE: ALL 32 lanes must execute
- Old code guarded TMEM ops with if(tid==0) = warp divergence = HANG
- tmem_dealloc now uses tmem_base (value from alloc), not SMEM pointer
- Compute attention in SMEM, then do one-way TMEM pipeline:
SMEM → TMEM (warp-collective store) → regs (warp-collective load)
→ normalize in regs → BF16 cast → GMEM
- This proves the MoE-style one-way correction epilogue on FMHA
Also: enable TMEM kernel test + hd=128 in standalone test
ROOT CAUSE of TMET hang: tcgen05.fence.cta_group::1.sync.aligned is
NOT a valid PTX instruction. The correct TMEM ordering primitives are:
- tcgen05.wait::st.sync.aligned (wait for TMEM stores to complete)
- tcgen05.wait::ld.sync.aligned (wait for TMEM loads to complete)
Found in cutlass/arch/barrier.h fence_view_async_tmem_store/load.
What changed:
- Moved fmha_backup_pre_epilog.py, fmha_backup_v2.py, fmha_smem_acc.py to archive/
- Deleted fmha.py.backup (git has history)
- Added detailed heredoc headers to ALL files documenting:
* WHAT WORKS and WHAT'S BROKEN
* WHY each limitation exists (CuTeDSL toolchain gaps)
* KEY INSIGHTS FOR NVIDIA (what CuTeDSL is missing)
* What each file unblocks if fixed
File status:
fmha.py — CuTeDSL FMHA, cos 0.999998, D1.5 workaround
fmha_common.cuh — Raw CUDA shared defs (BF16, TMEM ops)
fmha_sm100.cuh — Raw CUDA reference, cos 0.999999
fmha_epilogue_sm100.cuh — Raw CUDA TMEM epilogue, HANGS (needs debug)
fmha_sm100_launch.cu — PyTorch binding (JIT broken, nvcc works)
production.py — CuTeDSL production wrapper (partial)
archive/ — Historical backups with explanation headers
New file: fmha_epilogue_sm100.cuh
- TMEM alloc/dealloc/load/store via tcgen05 PTX
- One-way correction epilogue: TMEM→regs→normalize→BF16→GMEM
- D1.5 fix: O rescale in REGISTERS (TMEM→regs→multiply→TMEM)
- Same pattern as MoE epilogue but with normalize instead of SwiGLU
- Unblocks D2 multi-CTA and NVFP4-1.2 (register slot for FP4 pack)
Test: hd=64 + hd=128, reference vs TMEM kernels
Use thread 0 for all computation (slow but correct).
SMEM for Q and O sharing across threads.
Online softmax with O rescale — correct D1.5 approach.
D3 SWA mask implemented.
Target: cos ~0.999998 then parallelize.
Simpler approach first: scalar Q@K^T, softmax, P@V in registers.
No TMEM/MMA yet — verify correctness first, then replace with tcgen05.
- 192-thread CTA, all threads cooperate on one (batch, head)
- Online softmax with O rescale (correct D1.5 approach)
- D3 SWA mask, D4 causal (TODO), D5c sink (TODO)
- KV loaded in blocks of 128 for SMEM efficiency
- Correctness target: cos ~0.999998 against PyTorch reference
- tcgen05.mma.cta_group::1.kind::f16 [tmem_c], desc_a, desc_b, idescE_hi, scaleC, {mask0..3}, pred
- idescE is upper 32 bits of the E descriptor
- scaleC is a float (1.0 for accumulate)
- mask is 4 uint32 values (0xFFFFFFFF for no masking)