V is the B operand with block_mn=16 in the PV MMA. Its canonical layout
uses CORES_MN=16/8=2, not 128/8=16. The previous code used CORES_MN=16
which produced wrong canonical indexing → garbage PV output.
Also:
- V SMEM size is (16,16) canonical = 256 BF16, not (128,16) = 2048
- P written as 16 elements at row 0 (T=1 decode)
- V loaded from TMA (16,128) and sub-sampled to (16,16) canonical
- V TMA coord: {col_start, d_base} for (HD,s_k) tensor
The MMA expects Q sub-tiles from a full (128,HD) canonical buffer,
but we were only loading (128,16) sub-tiles into a (128,16) buffer.
The MMA descriptor with block_mn=128 describes a (128,128) matrix,
reading 128 columns from SMEM but only 16 had real data.
Now: load all HD/16 TMA tiles of Q into a full (128,HD) canonical
buffer before the QK loop. The MMA reads the kt-th sub-tile via
descriptor offset kt * 128 * 32 bytes.
Also: share single sTmaBuf staging buffer for all TMA loads (Q, K, V).
Removed separate sQ_tma, sK_tma, sV_tma buffers.
write_smem_canonical used NTHREADS=192 as the stride, but in the TMA
kernel only the load warp (32 threads) calls it. With threadIdx.x in
[160,191] and stride 192, only 32 out of 2048 elements got written.
Fix: template STRIDE parameter, default 192, TMA kernel uses 32.
The mbarrier is initialized once before the loop with count=1.
Inside the loop: issue TMA → arrive.expect_tx → wait(phase) → flip phase.
Re-initializing the mbarrier inside the loop resets the phase, which
breaks the parity tracking and causes the wait to hang.
This matches the CUTLASS/gau-nernst pattern exactly.
The TMA cp.async.bulk.tensor with mbarrier::complete_tx::bytes decrements
the mbarrier tx_count by the byte count of the transfer. Without calling
mbarrier.arrive.expect_tx to increment tx_count first, the count underflows
and the phase never completes — causing the wait to hang forever.
This was the root cause of the multi-warp TMA hang. With 32 threads it
worked by accident (phase parity wrapped around); with 128+ threads the
timing was different and the hang was exposed.
Also:
- Use CUTLASS-style @P1 bra DONE wait pattern (not selp.b32)
- Add fence.mbarrier_init.release.cluster after mbarrier init
- Track phase parity across the kernel (flip after each wait)
- Re-init mbarrier before each TMA transaction (proper phase management)
Reference: gau-nernst tcgen05 tutorial
Three critical CUDA 13 fixes that made TMA work:
1. globalStrides in BYTES not elements (root cause of desc creation failures)
2. BFLOAT16 data type instead of UINT16
3. mbarrier wait: selp.b32 polling pattern (@p bra HANGS on SM100!)
Also includes CUTLASS driver workaround (bit 21 clear for drv <= 13.1).
Verified: 2D TMA load of (128,16) BF16 tile = 0 mismatches.
Kernel: fmha_6warp_tma_kernel with per-sub-tile TMA loads for Q, K, V.
Test: test_fmha_tma.cu with padded Q allocations and per-head descriptors.
Key findings documented in docs/cuda13_tma_notes.md:
- CUDA 13 globalStrides are in BYTES not elements (root cause of desc creation failures)
- BFLOAT16 data type available in CUDA 13
- Driver API descriptors create OK but cp.async.bulk.tensor hangs on driver 13.0 + toolkit 13.2
- CuTeDSL tma_partition works (production path)
Archived (not deleted):
- fmha_tma_driver_api.cuh, fmha_6warp_tma_driver_api.cuh, test_fmha_tma_driver_api.cu
- These will work once driver matches toolkit version
Three critical fixes for CUDA 13.x on Blackwell:
1. globalStrides are in BYTES not elements (CUDA 13 change)
2. Use 3D descriptors (degenerate 3rd dim=1) — CUDA 13 TMA requires rank >= 2
3. mbarrier init uses expected byte count (4096 for 128x16 BF16 tile)
4. cp.async.bulk.tensor.3d instead of .2d for 3D descriptors
5. BFLOAT16 data type instead of UINT16