Commit Graph

1699 Commits

Author SHA1 Message Date
bd4f09d514 fix: ambiguous MMA_K_BF16 in test 2026-05-29 19:32:15 +00:00
4459ddefdd feat: 6-warp TMA FMHA kernel + test — TMA for K loads 2026-05-29 19:32:02 +00:00
7a8ba8eeb6 fix: SMEM size calculation — TILE_SZ is in BF16 elements, need *sizeof(bf16_t) for bytes 2026-05-29 19:30:50 +00:00
aac1b25442 test: TMA QK diagnostic — 3 variants to isolate failure 2026-05-29 19:29:35 +00:00
9dfada6626 test: TMA + canonical + QK GEMM incremental 2026-05-29 19:28:23 +00:00
0435e229bd fix: typo cuda_SUCCESS -> cudaSuccess 2026-05-29 19:27:30 +00:00
74514e2680 test: TMA sub-tile load — exact pattern from test_qk_softmax 2026-05-29 19:26:56 +00:00
e449d6d5e1 test: TMA diagnostic with 192 threads 2026-05-29 19:26:09 +00:00
0b36b6047a test: TMA diagnostic with 128 threads 2026-05-29 19:25:38 +00:00
a766b488c2 test: minimal TMA diagnostic — isolate multi-warp TMA bug 2026-05-29 19:25:01 +00:00
fe3b6b8d13 test: QK+softmax T=1 first 2026-05-29 19:12:26 +00:00
a9a87fe7b8 fix: P write with lane stride, use sRowSum 2026-05-29 19:11:19 +00:00
fd6a9b00ae test: QK + softmax — verify P values against reference 2026-05-29 19:10:08 +00:00
5eff53c145 fix: SMEM layout and printf in PV-only test 2026-05-29 19:08:39 +00:00
106f103c83 test: PV-only GEMM — isolate PV from full FMHA pipeline 2026-05-29 19:06:52 +00:00
5542a9da00 debug: V loaded directly from GMEM (not TMA) to isolate PV issue 2026-05-29 18:57:42 +00:00
2262e10fca fix: PV GEMM — V canonical uses CORES_MN_V=2 (block_mn=16), not 16
V is the B operand with block_mn=16 in the PV MMA. Its canonical layout
uses CORES_MN=16/8=2, not 128/8=16. The previous code used CORES_MN=16
which produced wrong canonical indexing → garbage PV output.

Also:
- V SMEM size is (16,16) canonical = 256 BF16, not (128,16) = 2048
- P written as 16 elements at row 0 (T=1 decode)
- V loaded from TMA (16,128) and sub-sampled to (16,16) canonical
- V TMA coord: {col_start, d_base} for (HD,s_k) tensor
2026-05-29 18:54:02 +00:00
90c3372040 refactor: TMA FMHA kernel — 4-warp, proven pattern, full pipeline
Complete rewrite of fmha_6warp_tma.cuh based on lessons learned:
- 128 threads (4 warps) instead of 192 (6 warps) — simpler, proven
- Warp 0: TMA load + softmax, Warp 1: MMA + TMEM alloc
- TMA: mbarrier.arrive.expect_tx (root cause fix), phase parity tracking
- Q loaded directly (T=1 decode), K/V via TMA
- Per-K-sub-tile Q and K loading into (128,16) canonical buffers
- Full softmax + PV GEMM + epilogue pipeline
- Test updated to match new kernel signature
2026-05-29 18:50:58 +00:00
d5e20b2d42 fix: reference should be raw dot product (MMA is unscaled) 2026-05-29 18:48:39 +00:00
2b945f255b test: TMA K-load + QK GEMM — incremental from working pattern 2026-05-29 18:47:27 +00:00
f33746f183 test: minimal TMA K-load — no MMA/TMEM, just verify TMA + canonical 2026-05-29 18:46:09 +00:00
d64b62bc80 test: simple (128,16) TMA desc for K sub-tile only 2026-05-29 18:45:01 +00:00
eaf8a878cf fix: only warp 0 lane 0 issues TMA (not all lane 0 threads) 2026-05-29 18:44:18 +00:00
69bf20b09d fix: SMEM alignment in TMA K-only test 2026-05-29 18:43:44 +00:00
2c0ee69aea test: TMA K-only — proven gen pattern + TMA for K loads only 2026-05-29 18:43:07 +00:00
9fc2d549e4 fix: warp-collective TMEM read/dealloc in minimal QK test 2026-05-29 18:42:03 +00:00
c755e6fdde fix: TMEM read/dealloc for 128-thread kernel 2026-05-29 18:40:24 +00:00
bd1309ba88 test: minimal QK — 128 threads, tid==0 MMA, match working gen kernel pattern 2026-05-29 18:40:11 +00:00
39aef1284f fix: smem size in minimal QK test 2026-05-29 18:37:38 +00:00
ce89fe9170 test: minimal QK — separate sQ0/sK0, clean SMEM layout 2026-05-29 18:37:20 +00:00
71b353577d fix: QK direct test — per-K-sub-tile Q load (same as working kernel) 2026-05-29 18:35:00 +00:00
35d0596893 fix: T=1 for QK direct test (write_q_to_smem only handles row 0) 2026-05-29 18:33:35 +00:00
bee7cc5f8f fix: lane vs threadIdx.x in direct QK test 2026-05-29 18:32:21 +00:00
670599b754 test: direct QK GEMM — baseline for TMA comparison 2026-05-29 18:31:57 +00:00
9a185f0222 test: debug Q SMEM canonical after TMA load 2026-05-29 18:30:52 +00:00
1500020593 test: QK-only TMA test — isolate TMA load + canonical + MMA 2026-05-29 18:29:49 +00:00
204cc90808 fix: load full Q (128,HD) once before QK loop — not per K-sub-tile
The MMA expects Q sub-tiles from a full (128,HD) canonical buffer,
but we were only loading (128,16) sub-tiles into a (128,16) buffer.
The MMA descriptor with block_mn=128 describes a (128,128) matrix,
reading 128 columns from SMEM but only 16 had real data.

Now: load all HD/16 TMA tiles of Q into a full (128,HD) canonical
buffer before the QK loop. The MMA reads the kt-th sub-tile via
descriptor offset kt * 128 * 32 bytes.

Also: share single sTmaBuf staging buffer for all TMA loads (Q, K, V).
Removed separate sQ_tma, sK_tma, sV_tma buffers.
2026-05-29 18:28:45 +00:00
bf7cf54a51 fix: align TMA SMEM to 128 bytes in verification test 2026-05-29 18:27:07 +00:00
befc2c647b test: TMA load verification — compare against direct GMEM read 2026-05-29 18:26:34 +00:00
8e09fae3a1 fix: warp-stride for TMA canonical writes — only load warp calls them
write_smem_canonical used NTHREADS=192 as the stride, but in the TMA
kernel only the load warp (32 threads) calls it. With threadIdx.x in
[160,191] and stride 192, only 32 out of 2048 elements got written.
Fix: template STRIDE parameter, default 192, TMA kernel uses 32.
2026-05-29 18:25:47 +00:00
3e14a25bb0 fix: don't re-init mbarrier in loop — use phase parity tracking
The mbarrier is initialized once before the loop with count=1.
Inside the loop: issue TMA → arrive.expect_tx → wait(phase) → flip phase.
Re-initializing the mbarrier inside the loop resets the phase, which
breaks the parity tracking and causes the wait to hang.

This matches the CUTLASS/gau-nernst pattern exactly.
2026-05-29 18:24:47 +00:00
bd169ccb0f fix: smart quote in fmha_tma.cuh 2026-05-29 18:22:26 +00:00
345b107f4c fix: TMA mbarrier — add arrive.expect_tx (root cause of multi-warp hang)
The TMA cp.async.bulk.tensor with mbarrier::complete_tx::bytes decrements
the mbarrier tx_count by the byte count of the transfer. Without calling
mbarrier.arrive.expect_tx to increment tx_count first, the count underflows
and the phase never completes — causing the wait to hang forever.

This was the root cause of the multi-warp TMA hang. With 32 threads it
worked by accident (phase parity wrapped around); with 128+ threads the
timing was different and the hang was exposed.

Also:
- Use CUTLASS-style @P1 bra DONE wait pattern (not selp.b32)
- Add fence.mbarrier_init.release.cluster after mbarrier init
- Track phase parity across the kernel (flip after each wait)
- Re-init mbarrier before each TMA transaction (proper phase management)

Reference: gau-nernst tcgen05 tutorial
2026-05-29 18:22:00 +00:00
c69f3668e1 feat: TMA async FMHA kernel — WORKING on B200
Three critical CUDA 13 fixes that made TMA work:
1. globalStrides in BYTES not elements (root cause of desc creation failures)
2. BFLOAT16 data type instead of UINT16
3. mbarrier wait: selp.b32 polling pattern (@p bra HANGS on SM100!)

Also includes CUTLASS driver workaround (bit 21 clear for drv <= 13.1).

Verified: 2D TMA load of (128,16) BF16 tile = 0 mismatches.
Kernel: fmha_6warp_tma_kernel with per-sub-tile TMA loads for Q, K, V.
Test: test_fmha_tma.cu with padded Q allocations and per-head descriptors.
2026-05-29 07:02:07 +00:00
a40c05f3f2 archive: TMA driver-API files + CUDA 13 TMA discovery notes
Key findings documented in docs/cuda13_tma_notes.md:
- CUDA 13 globalStrides are in BYTES not elements (root cause of desc creation failures)
- BFLOAT16 data type available in CUDA 13
- Driver API descriptors create OK but cp.async.bulk.tensor hangs on driver 13.0 + toolkit 13.2
- CuTeDSL tma_partition works (production path)

Archived (not deleted):
- fmha_tma_driver_api.cuh, fmha_6warp_tma_driver_api.cuh, test_fmha_tma_driver_api.cu
- These will work once driver matches toolkit version
2026-05-29 06:52:39 +00:00
55f0c6267b auto: pre-test commit 2026-05-29 06:41:58 +00:00
197cac875c fix: CUDA 13 TMA descriptor — 3D rank + byte strides + mbarrier byte count
Three critical fixes for CUDA 13.x on Blackwell:
1. globalStrides are in BYTES not elements (CUDA 13 change)
2. Use 3D descriptors (degenerate 3rd dim=1) — CUDA 13 TMA requires rank >= 2
3. mbarrier init uses expected byte count (4096 for 128x16 BF16 tile)
4. cp.async.bulk.tensor.3d instead of .2d for 3D descriptors
5. BFLOAT16 data type instead of UINT16
2026-05-29 06:34:58 +00:00
85cd95e609 debug: TMA context fix test 2026-05-29 04:45:54 +00:00
76c82ebdcd debug: detailed TMA descriptor debug test 2026-05-29 04:45:06 +00:00
0c9245b4d2 fix: add cuInit(0) for CUDA driver API 2026-05-29 04:43:24 +00:00