Commit Graph

1695 Commits

Author SHA1 Message Date
9dfada6626 test: TMA + canonical + QK GEMM incremental 2026-05-29 19:28:23 +00:00
0435e229bd fix: typo cuda_SUCCESS -> cudaSuccess 2026-05-29 19:27:30 +00:00
74514e2680 test: TMA sub-tile load — exact pattern from test_qk_softmax 2026-05-29 19:26:56 +00:00
e449d6d5e1 test: TMA diagnostic with 192 threads 2026-05-29 19:26:09 +00:00
0b36b6047a test: TMA diagnostic with 128 threads 2026-05-29 19:25:38 +00:00
a766b488c2 test: minimal TMA diagnostic — isolate multi-warp TMA bug 2026-05-29 19:25:01 +00:00
fe3b6b8d13 test: QK+softmax T=1 first 2026-05-29 19:12:26 +00:00
a9a87fe7b8 fix: P write with lane stride, use sRowSum 2026-05-29 19:11:19 +00:00
fd6a9b00ae test: QK + softmax — verify P values against reference 2026-05-29 19:10:08 +00:00
5eff53c145 fix: SMEM layout and printf in PV-only test 2026-05-29 19:08:39 +00:00
106f103c83 test: PV-only GEMM — isolate PV from full FMHA pipeline 2026-05-29 19:06:52 +00:00
5542a9da00 debug: V loaded directly from GMEM (not TMA) to isolate PV issue 2026-05-29 18:57:42 +00:00
2262e10fca fix: PV GEMM — V canonical uses CORES_MN_V=2 (block_mn=16), not 16
V is the B operand with block_mn=16 in the PV MMA. Its canonical layout
uses CORES_MN=16/8=2, not 128/8=16. The previous code used CORES_MN=16
which produced wrong canonical indexing → garbage PV output.

Also:
- V SMEM size is (16,16) canonical = 256 BF16, not (128,16) = 2048
- P written as 16 elements at row 0 (T=1 decode)
- V loaded from TMA (16,128) and sub-sampled to (16,16) canonical
- V TMA coord: {col_start, d_base} for (HD,s_k) tensor
2026-05-29 18:54:02 +00:00
90c3372040 refactor: TMA FMHA kernel — 4-warp, proven pattern, full pipeline
Complete rewrite of fmha_6warp_tma.cuh based on lessons learned:
- 128 threads (4 warps) instead of 192 (6 warps) — simpler, proven
- Warp 0: TMA load + softmax, Warp 1: MMA + TMEM alloc
- TMA: mbarrier.arrive.expect_tx (root cause fix), phase parity tracking
- Q loaded directly (T=1 decode), K/V via TMA
- Per-K-sub-tile Q and K loading into (128,16) canonical buffers
- Full softmax + PV GEMM + epilogue pipeline
- Test updated to match new kernel signature
2026-05-29 18:50:58 +00:00
d5e20b2d42 fix: reference should be raw dot product (MMA is unscaled) 2026-05-29 18:48:39 +00:00
2b945f255b test: TMA K-load + QK GEMM — incremental from working pattern 2026-05-29 18:47:27 +00:00
f33746f183 test: minimal TMA K-load — no MMA/TMEM, just verify TMA + canonical 2026-05-29 18:46:09 +00:00
d64b62bc80 test: simple (128,16) TMA desc for K sub-tile only 2026-05-29 18:45:01 +00:00
eaf8a878cf fix: only warp 0 lane 0 issues TMA (not all lane 0 threads) 2026-05-29 18:44:18 +00:00
69bf20b09d fix: SMEM alignment in TMA K-only test 2026-05-29 18:43:44 +00:00
2c0ee69aea test: TMA K-only — proven gen pattern + TMA for K loads only 2026-05-29 18:43:07 +00:00
9fc2d549e4 fix: warp-collective TMEM read/dealloc in minimal QK test 2026-05-29 18:42:03 +00:00
c755e6fdde fix: TMEM read/dealloc for 128-thread kernel 2026-05-29 18:40:24 +00:00
bd1309ba88 test: minimal QK — 128 threads, tid==0 MMA, match working gen kernel pattern 2026-05-29 18:40:11 +00:00
39aef1284f fix: smem size in minimal QK test 2026-05-29 18:37:38 +00:00
ce89fe9170 test: minimal QK — separate sQ0/sK0, clean SMEM layout 2026-05-29 18:37:20 +00:00
71b353577d fix: QK direct test — per-K-sub-tile Q load (same as working kernel) 2026-05-29 18:35:00 +00:00
35d0596893 fix: T=1 for QK direct test (write_q_to_smem only handles row 0) 2026-05-29 18:33:35 +00:00
bee7cc5f8f fix: lane vs threadIdx.x in direct QK test 2026-05-29 18:32:21 +00:00
670599b754 test: direct QK GEMM — baseline for TMA comparison 2026-05-29 18:31:57 +00:00
9a185f0222 test: debug Q SMEM canonical after TMA load 2026-05-29 18:30:52 +00:00
1500020593 test: QK-only TMA test — isolate TMA load + canonical + MMA 2026-05-29 18:29:49 +00:00
204cc90808 fix: load full Q (128,HD) once before QK loop — not per K-sub-tile
The MMA expects Q sub-tiles from a full (128,HD) canonical buffer,
but we were only loading (128,16) sub-tiles into a (128,16) buffer.
The MMA descriptor with block_mn=128 describes a (128,128) matrix,
reading 128 columns from SMEM but only 16 had real data.

Now: load all HD/16 TMA tiles of Q into a full (128,HD) canonical
buffer before the QK loop. The MMA reads the kt-th sub-tile via
descriptor offset kt * 128 * 32 bytes.

Also: share single sTmaBuf staging buffer for all TMA loads (Q, K, V).
Removed separate sQ_tma, sK_tma, sV_tma buffers.
2026-05-29 18:28:45 +00:00
bf7cf54a51 fix: align TMA SMEM to 128 bytes in verification test 2026-05-29 18:27:07 +00:00
befc2c647b test: TMA load verification — compare against direct GMEM read 2026-05-29 18:26:34 +00:00
8e09fae3a1 fix: warp-stride for TMA canonical writes — only load warp calls them
write_smem_canonical used NTHREADS=192 as the stride, but in the TMA
kernel only the load warp (32 threads) calls it. With threadIdx.x in
[160,191] and stride 192, only 32 out of 2048 elements got written.
Fix: template STRIDE parameter, default 192, TMA kernel uses 32.
2026-05-29 18:25:47 +00:00
3e14a25bb0 fix: don't re-init mbarrier in loop — use phase parity tracking
The mbarrier is initialized once before the loop with count=1.
Inside the loop: issue TMA → arrive.expect_tx → wait(phase) → flip phase.
Re-initializing the mbarrier inside the loop resets the phase, which
breaks the parity tracking and causes the wait to hang.

This matches the CUTLASS/gau-nernst pattern exactly.
2026-05-29 18:24:47 +00:00
bd169ccb0f fix: smart quote in fmha_tma.cuh 2026-05-29 18:22:26 +00:00
345b107f4c fix: TMA mbarrier — add arrive.expect_tx (root cause of multi-warp hang)
The TMA cp.async.bulk.tensor with mbarrier::complete_tx::bytes decrements
the mbarrier tx_count by the byte count of the transfer. Without calling
mbarrier.arrive.expect_tx to increment tx_count first, the count underflows
and the phase never completes — causing the wait to hang forever.

This was the root cause of the multi-warp TMA hang. With 32 threads it
worked by accident (phase parity wrapped around); with 128+ threads the
timing was different and the hang was exposed.

Also:
- Use CUTLASS-style @P1 bra DONE wait pattern (not selp.b32)
- Add fence.mbarrier_init.release.cluster after mbarrier init
- Track phase parity across the kernel (flip after each wait)
- Re-init mbarrier before each TMA transaction (proper phase management)

Reference: gau-nernst tcgen05 tutorial
2026-05-29 18:22:00 +00:00
c69f3668e1 feat: TMA async FMHA kernel — WORKING on B200
Three critical CUDA 13 fixes that made TMA work:
1. globalStrides in BYTES not elements (root cause of desc creation failures)
2. BFLOAT16 data type instead of UINT16
3. mbarrier wait: selp.b32 polling pattern (@p bra HANGS on SM100!)

Also includes CUTLASS driver workaround (bit 21 clear for drv <= 13.1).

Verified: 2D TMA load of (128,16) BF16 tile = 0 mismatches.
Kernel: fmha_6warp_tma_kernel with per-sub-tile TMA loads for Q, K, V.
Test: test_fmha_tma.cu with padded Q allocations and per-head descriptors.
2026-05-29 07:02:07 +00:00
a40c05f3f2 archive: TMA driver-API files + CUDA 13 TMA discovery notes
Key findings documented in docs/cuda13_tma_notes.md:
- CUDA 13 globalStrides are in BYTES not elements (root cause of desc creation failures)
- BFLOAT16 data type available in CUDA 13
- Driver API descriptors create OK but cp.async.bulk.tensor hangs on driver 13.0 + toolkit 13.2
- CuTeDSL tma_partition works (production path)

Archived (not deleted):
- fmha_tma_driver_api.cuh, fmha_6warp_tma_driver_api.cuh, test_fmha_tma_driver_api.cu
- These will work once driver matches toolkit version
2026-05-29 06:52:39 +00:00
55f0c6267b auto: pre-test commit 2026-05-29 06:41:58 +00:00
197cac875c fix: CUDA 13 TMA descriptor — 3D rank + byte strides + mbarrier byte count
Three critical fixes for CUDA 13.x on Blackwell:
1. globalStrides are in BYTES not elements (CUDA 13 change)
2. Use 3D descriptors (degenerate 3rd dim=1) — CUDA 13 TMA requires rank >= 2
3. mbarrier init uses expected byte count (4096 for 128x16 BF16 tile)
4. cp.async.bulk.tensor.3d instead of .2d for 3D descriptors
5. BFLOAT16 data type instead of UINT16
2026-05-29 06:34:58 +00:00
85cd95e609 debug: TMA context fix test 2026-05-29 04:45:54 +00:00
76c82ebdcd debug: detailed TMA descriptor debug test 2026-05-29 04:45:06 +00:00
0c9245b4d2 fix: add cuInit(0) for CUDA driver API 2026-05-29 04:43:24 +00:00
6cc2f61431 debug: TMA descriptor dimension test 2026-05-29 04:42:44 +00:00
3412ff1a9b fix: TMA tile strides must match global strides, not tile dimensions
The tile stride in the outer dimension should be the global row stride
(cols), not the tile width. The tile is a window into the global tensor
and elements are addressed with global strides.
2026-05-29 04:41:53 +00:00
409838ace2 refactor: per-sub-tile TMA loads with padded GMEM allocations
- Q, K, V all loaded per (128,16) sub-tile via TMA
- Q GMEM padded to (128, HD) to satisfy TMA tile requirements
- Simpler SMEM layout — only (128,16) staging buffers needed
- Updated test with padded allocations
2026-05-29 04:41:03 +00:00
8c17f65f5b fix: cast typo 2026-05-29 04:39:21 +00:00