The MMA expects Q sub-tiles from a full (128,HD) canonical buffer,
but we were only loading (128,16) sub-tiles into a (128,16) buffer.
The MMA descriptor with block_mn=128 describes a (128,128) matrix,
reading 128 columns from SMEM but only 16 had real data.
Now: load all HD/16 TMA tiles of Q into a full (128,HD) canonical
buffer before the QK loop. The MMA reads the kt-th sub-tile via
descriptor offset kt * 128 * 32 bytes.
Also: share single sTmaBuf staging buffer for all TMA loads (Q, K, V).
Removed separate sQ_tma, sK_tma, sV_tma buffers.
write_smem_canonical used NTHREADS=192 as the stride, but in the TMA
kernel only the load warp (32 threads) calls it. With threadIdx.x in
[160,191] and stride 192, only 32 out of 2048 elements got written.
Fix: template STRIDE parameter, default 192, TMA kernel uses 32.
The mbarrier is initialized once before the loop with count=1.
Inside the loop: issue TMA → arrive.expect_tx → wait(phase) → flip phase.
Re-initializing the mbarrier inside the loop resets the phase, which
breaks the parity tracking and causes the wait to hang.
This matches the CUTLASS/gau-nernst pattern exactly.
The TMA cp.async.bulk.tensor with mbarrier::complete_tx::bytes decrements
the mbarrier tx_count by the byte count of the transfer. Without calling
mbarrier.arrive.expect_tx to increment tx_count first, the count underflows
and the phase never completes — causing the wait to hang forever.
This was the root cause of the multi-warp TMA hang. With 32 threads it
worked by accident (phase parity wrapped around); with 128+ threads the
timing was different and the hang was exposed.
Also:
- Use CUTLASS-style @P1 bra DONE wait pattern (not selp.b32)
- Add fence.mbarrier_init.release.cluster after mbarrier init
- Track phase parity across the kernel (flip after each wait)
- Re-init mbarrier before each TMA transaction (proper phase management)
Reference: gau-nernst tcgen05 tutorial
Three critical CUDA 13 fixes that made TMA work:
1. globalStrides in BYTES not elements (root cause of desc creation failures)
2. BFLOAT16 data type instead of UINT16
3. mbarrier wait: selp.b32 polling pattern (@p bra HANGS on SM100!)
Also includes CUTLASS driver workaround (bit 21 clear for drv <= 13.1).
Verified: 2D TMA load of (128,16) BF16 tile = 0 mismatches.
Kernel: fmha_6warp_tma_kernel with per-sub-tile TMA loads for Q, K, V.
Test: test_fmha_tma.cu with padded Q allocations and per-head descriptors.
Key findings documented in docs/cuda13_tma_notes.md:
- CUDA 13 globalStrides are in BYTES not elements (root cause of desc creation failures)
- BFLOAT16 data type available in CUDA 13
- Driver API descriptors create OK but cp.async.bulk.tensor hangs on driver 13.0 + toolkit 13.2
- CuTeDSL tma_partition works (production path)
Archived (not deleted):
- fmha_tma_driver_api.cuh, fmha_6warp_tma_driver_api.cuh, test_fmha_tma_driver_api.cu
- These will work once driver matches toolkit version
Three critical fixes for CUDA 13.x on Blackwell:
1. globalStrides are in BYTES not elements (CUDA 13 change)
2. Use 3D descriptors (degenerate 3rd dim=1) — CUDA 13 TMA requires rank >= 2
3. mbarrier init uses expected byte count (4096 for 128x16 BF16 tile)
4. cp.async.bulk.tensor.3d instead of .2d for 3D descriptors
5. BFLOAT16 data type instead of UINT16
The tile stride in the outer dimension should be the global row stride
(cols), not the tile width. The tile is a window into the global tensor
and elements are addressed with global strides.
- Q, K, V all loaded per (128,16) sub-tile via TMA
- Q GMEM padded to (128, HD) to satisfy TMA tile requirements
- Simpler SMEM layout — only (128,16) staging buffers needed
- Updated test with padded allocations
KEY FIX: TMEM is shared between QK output (S) and PV output (O).
Cannot interleave softmax reads with PV writes because PV overwrites S.
New flow:
1. QK GEMM → S in TMEM
2. Softmax: read ALL S from TMEM, compute P in registers
- Pass 1: row_max (4 warps, 32x32b.x8)
- Pass 2: exp, sum, store P in p_vals[SK_TILE] registers
3. PV GEMM: write P to sPk per K-tile, accumulate O in TMEM
4. Epilogue: read O from TMEM, normalize, write GMEM
P in registers: each lane holds float p_vals[128] = 512 bytes.
Register budget: 128 lanes × 512B = 64KB (within B200 256KB register file).