Commit Graph

438 Commits

Author SHA1 Message Date
d47b2bfcce fix: use un-normalized P for multi-tile PV (correct online softmax merge) 2026-05-29 19:57:54 +00:00
43ae3e7f98 fix: reload Q per-K-sub-tile in multi-tile kernel (same as single-tile) 2026-05-29 19:56:35 +00:00
8e99bd50e6 feat: 6-warp TMA multi-tile KV kernel with register accumulator + test 2026-05-29 19:49:53 +00:00
1814510195 wip: add n_kv_tiles param for multi-tile KV (not yet used) 2026-05-29 19:47:48 +00:00
d20792aa9d fix: TMA descriptor index for batched multi-head (batch*n_h + head) 2026-05-29 19:45:44 +00:00
754c6a692c feat: per-head TMA descriptors for multi-head FMHA 2026-05-29 19:44:58 +00:00
a1b2ab79a1 feat: 6-warp TMA FMHA multi-row kernel + test 2026-05-29 19:39:17 +00:00
d0a50f1f2e fix: remove double normalization in TMA epilogue (P already normalized before PV) 2026-05-29 19:36:41 +00:00
fb971781aa fix: revert V to direct load (V TMA needs debugging), K TMA works 2026-05-29 19:35:44 +00:00
cd2c028b39 feat: TMA loads for both K and V in 6-warp FMHA kernel 2026-05-29 19:34:48 +00:00
4459ddefdd feat: 6-warp TMA FMHA kernel + test — TMA for K loads 2026-05-29 19:32:02 +00:00
5542a9da00 debug: V loaded directly from GMEM (not TMA) to isolate PV issue 2026-05-29 18:57:42 +00:00
2262e10fca fix: PV GEMM — V canonical uses CORES_MN_V=2 (block_mn=16), not 16
V is the B operand with block_mn=16 in the PV MMA. Its canonical layout
uses CORES_MN=16/8=2, not 128/8=16. The previous code used CORES_MN=16
which produced wrong canonical indexing → garbage PV output.

Also:
- V SMEM size is (16,16) canonical = 256 BF16, not (128,16) = 2048
- P written as 16 elements at row 0 (T=1 decode)
- V loaded from TMA (16,128) and sub-sampled to (16,16) canonical
- V TMA coord: {col_start, d_base} for (HD,s_k) tensor
2026-05-29 18:54:02 +00:00
90c3372040 refactor: TMA FMHA kernel — 4-warp, proven pattern, full pipeline
Complete rewrite of fmha_6warp_tma.cuh based on lessons learned:
- 128 threads (4 warps) instead of 192 (6 warps) — simpler, proven
- Warp 0: TMA load + softmax, Warp 1: MMA + TMEM alloc
- TMA: mbarrier.arrive.expect_tx (root cause fix), phase parity tracking
- Q loaded directly (T=1 decode), K/V via TMA
- Per-K-sub-tile Q and K loading into (128,16) canonical buffers
- Full softmax + PV GEMM + epilogue pipeline
- Test updated to match new kernel signature
2026-05-29 18:50:58 +00:00
204cc90808 fix: load full Q (128,HD) once before QK loop — not per K-sub-tile
The MMA expects Q sub-tiles from a full (128,HD) canonical buffer,
but we were only loading (128,16) sub-tiles into a (128,16) buffer.
The MMA descriptor with block_mn=128 describes a (128,128) matrix,
reading 128 columns from SMEM but only 16 had real data.

Now: load all HD/16 TMA tiles of Q into a full (128,HD) canonical
buffer before the QK loop. The MMA reads the kt-th sub-tile via
descriptor offset kt * 128 * 32 bytes.

Also: share single sTmaBuf staging buffer for all TMA loads (Q, K, V).
Removed separate sQ_tma, sK_tma, sV_tma buffers.
2026-05-29 18:28:45 +00:00
8e09fae3a1 fix: warp-stride for TMA canonical writes — only load warp calls them
write_smem_canonical used NTHREADS=192 as the stride, but in the TMA
kernel only the load warp (32 threads) calls it. With threadIdx.x in
[160,191] and stride 192, only 32 out of 2048 elements got written.
Fix: template STRIDE parameter, default 192, TMA kernel uses 32.
2026-05-29 18:25:47 +00:00
3e14a25bb0 fix: don't re-init mbarrier in loop — use phase parity tracking
The mbarrier is initialized once before the loop with count=1.
Inside the loop: issue TMA → arrive.expect_tx → wait(phase) → flip phase.
Re-initializing the mbarrier inside the loop resets the phase, which
breaks the parity tracking and causes the wait to hang.

This matches the CUTLASS/gau-nernst pattern exactly.
2026-05-29 18:24:47 +00:00
bd169ccb0f fix: smart quote in fmha_tma.cuh 2026-05-29 18:22:26 +00:00
345b107f4c fix: TMA mbarrier — add arrive.expect_tx (root cause of multi-warp hang)
The TMA cp.async.bulk.tensor with mbarrier::complete_tx::bytes decrements
the mbarrier tx_count by the byte count of the transfer. Without calling
mbarrier.arrive.expect_tx to increment tx_count first, the count underflows
and the phase never completes — causing the wait to hang forever.

This was the root cause of the multi-warp TMA hang. With 32 threads it
worked by accident (phase parity wrapped around); with 128+ threads the
timing was different and the hang was exposed.

Also:
- Use CUTLASS-style @P1 bra DONE wait pattern (not selp.b32)
- Add fence.mbarrier_init.release.cluster after mbarrier init
- Track phase parity across the kernel (flip after each wait)
- Re-init mbarrier before each TMA transaction (proper phase management)

Reference: gau-nernst tcgen05 tutorial
2026-05-29 18:22:00 +00:00
c69f3668e1 feat: TMA async FMHA kernel — WORKING on B200
Three critical CUDA 13 fixes that made TMA work:
1. globalStrides in BYTES not elements (root cause of desc creation failures)
2. BFLOAT16 data type instead of UINT16
3. mbarrier wait: selp.b32 polling pattern (@p bra HANGS on SM100!)

Also includes CUTLASS driver workaround (bit 21 clear for drv <= 13.1).

Verified: 2D TMA load of (128,16) BF16 tile = 0 mismatches.
Kernel: fmha_6warp_tma_kernel with per-sub-tile TMA loads for Q, K, V.
Test: test_fmha_tma.cu with padded Q allocations and per-head descriptors.
2026-05-29 07:02:07 +00:00
a40c05f3f2 archive: TMA driver-API files + CUDA 13 TMA discovery notes
Key findings documented in docs/cuda13_tma_notes.md:
- CUDA 13 globalStrides are in BYTES not elements (root cause of desc creation failures)
- BFLOAT16 data type available in CUDA 13
- Driver API descriptors create OK but cp.async.bulk.tensor hangs on driver 13.0 + toolkit 13.2
- CuTeDSL tma_partition works (production path)

Archived (not deleted):
- fmha_tma_driver_api.cuh, fmha_6warp_tma_driver_api.cuh, test_fmha_tma_driver_api.cu
- These will work once driver matches toolkit version
2026-05-29 06:52:39 +00:00
197cac875c fix: CUDA 13 TMA descriptor — 3D rank + byte strides + mbarrier byte count
Three critical fixes for CUDA 13.x on Blackwell:
1. globalStrides are in BYTES not elements (CUDA 13 change)
2. Use 3D descriptors (degenerate 3rd dim=1) — CUDA 13 TMA requires rank >= 2
3. mbarrier init uses expected byte count (4096 for 128x16 BF16 tile)
4. cp.async.bulk.tensor.3d instead of .2d for 3D descriptors
5. BFLOAT16 data type instead of UINT16
2026-05-29 06:34:58 +00:00
3412ff1a9b fix: TMA tile strides must match global strides, not tile dimensions
The tile stride in the outer dimension should be the global row stride
(cols), not the tile width. The tile is a window into the global tensor
and elements are addressed with global strides.
2026-05-29 04:41:53 +00:00
409838ace2 refactor: per-sub-tile TMA loads with padded GMEM allocations
- Q, K, V all loaded per (128,16) sub-tile via TMA
- Q GMEM padded to (128, HD) to satisfy TMA tile requirements
- Simpler SMEM layout — only (128,16) staging buffers needed
- Updated test with padded allocations
2026-05-29 04:41:03 +00:00
b78ebe8a9c debug: add TMA descriptor error reporting 2026-05-29 04:38:57 +00:00
c7a6d7d231 fix: tma_mbar_init → tma_mbarrier_init (typo) 2026-05-29 04:37:48 +00:00
696462f07a feat: TMA async load infrastructure for FMHA kernel
- fmha_tma.cuh: TMA descriptor creation, mbarrier helpers, cp.async.bulk.tensor.2d wrappers
- fmha_6warp_tma.cuh: TMA-integrated multirow kernel with async GMEM→SMEM loads
  - TMA loads Q, K, V tiles to row-major SMEM
  - Transposes to canonical K-major layout for MMA
  - Same softmax/epilogue as non-TMA kernel
- test_fmha_tma.cu: Test harness for TMA FMHA (HD=64 first)
2026-05-29 04:36:52 +00:00
2f2259395e fix: always normalize in kernel, correct KV merge with normalized O + LSE 2026-05-28 23:53:44 +00:00
914f76d30c multirow: add normalize flag, un-norm + LSE output, multi-tile KV merge test 2026-05-28 23:51:23 +00:00
ac8fa779e2 fix: move epilogue TMEM loads outside my_row_active guard (warp-collective hang) 2026-05-28 23:46:46 +00:00
55c0604a71 add fence.sc.gpu between PV and epilogue for TMEM visibility 2026-05-28 23:21:53 +00:00
52809b0ec6 fix: tcgen05.wait::ld.sync.aligned (was missing 'sync') 2026-05-28 23:19:03 +00:00
468614a4e2 fmha_multirow: non-interleaved design — softmax first, then PV
KEY FIX: TMEM is shared between QK output (S) and PV output (O).
Cannot interleave softmax reads with PV writes because PV overwrites S.

New flow:
1. QK GEMM → S in TMEM
2. Softmax: read ALL S from TMEM, compute P in registers
   - Pass 1: row_max (4 warps, 32x32b.x8)
   - Pass 2: exp, sum, store P in p_vals[SK_TILE] registers
3. PV GEMM: write P to sPk per K-tile, accumulate O in TMEM
4. Epilogue: read O from TMEM, normalize, write GMEM

P in registers: each lane holds float p_vals[128] = 512 bytes.
Register budget: 128 lanes × 512B = 64KB (within B200 256KB register file).
2026-05-28 23:17:43 +00:00
43ba672e15 fmha_multirow: add fence.sc.gpu after QK GEMM for TMEM visibility 2026-05-28 23:13:31 +00:00
be2685e9e3 fmha_multirow: use natural 4-warp TMEM partitioning after UMMA
After UMMA (QK GEMM), 4 warps reading TMEM with 32x32b.x8 each
see a different 32-row partition (verified on B200):
  Warp 0 → rows 0-31, Warp 1 → rows 32-63, etc.
Lane l in warp w reads row w*32 + l.

This eliminates the broken row_page<<16 addressing and allows:
- T<=32: warp 0 only, 32x32b.x8, each lane = one row
- T>32: 4 warps, each reads its natural 32-row partition
- Epilogue: same partitioning for reading O from TMEM

No s_p_vals buffer. P streamed per K-tile through sPk.
2026-05-28 23:07:31 +00:00
863a030c3b fmha_multirow: rewrite with 32x32b.x8 only, no s_p_vals, row_page addressing
- Kill 64KB s_p_vals buffer — P is streamed per K-tile through sPk
- All TMEM ops use 32x32b.x8 exclusively (16x256b.x1 crashes on 2nd call)
- T>32: 4 softmax warps use row_page offset in TMEM address (row<<16)
- Lane l in warp w handles row w*32+l
- Two-pass softmax: pass 1 row_max, pass 2 exp/sum interleaved with PV
- PV: N=16 sub-tiles, SS MMA sPk(128,16) × sV(16,16) → TMEM
- Epilogue: 32x32b.x8 TMEM read, normalize, BF16 → GMEM
- SMEM budget: ~14KB (well within 232KB)
2026-05-28 22:52:52 +00:00
1ba304db3e stuff 2026-05-28 21:08:13 +00:00
deaa3ec725 CRITICAL FIX: Q/K SMEM canonical layout must use local d (0..15) not full_d — UMMA descriptor reads from sQ0/sK0 start, not offset 2026-05-28 20:13:52 +00:00
08694b8136 Fix multi-row softmax v3: 32x32b.x8 with per-lane per-row (no wmax/wsum), per-row sRowMax/sRowSum arrays 2026-05-28 20:10:13 +00:00
aaa76c1af1 Rewrite multi-row softmax using 16x256b.x1 TMEM reads for proper multi-row access 2026-05-28 20:08:30 +00:00
5e3c61184c Fix multi-row softmax: remove cross-lane wmax/wsum — each lane handles its own row independently 2026-05-28 20:06:16 +00:00
d8b421ccee Multi-row FMHA kernel (Milestone 4): T>1 prefill support with 4-warp parallel softmax 2026-05-28 20:04:29 +00:00
aa41cfa2e5 Multi-head FMHA kernel (Milestone 5): grid launch with MHA/MQA/batch support
- fmha_6warp_multihead.cuh: grid=(1, n_h, batch) kernel with FmhaParams
- MQA support via k_head_stride=0 / v_head_stride=0
- LSE output for multi-segment KV merge composition
- test_fmha_6warp_multihead.cu: MHA (4+8 heads), MQA, batched tests
- HD-specific wrappers for hd=16/64/128/256
- Marked E2M1 dequant bug as FIXED in consultant issue file
2026-05-28 19:32:35 +00:00
6af2feb42a TMA 5D test: element stride decomposition 2026-05-28 19:18:01 +00:00
41343fdc6b auto: pre-test commit 2026-05-28 19:08:04 +00:00
b3020c2811 6-warp specialized FMHA kernel — ALL HD=16/64/128/256 PASS cos 0.999997+
Warp layout (192 threads):
- Warps 0-3: Softmax + correction + epilogue
- Warp 4: MMA (QK + PV GEMM)
- Warp 5: Data staging (Q/K/V loads, direct GMEM for now)
CTA-wide __syncthreads() sync between phases.

Fix: removed spurious inv_sum normalization in epilogue
(MMA output is already correctly scaled with softmax'd P).

Files: fmha_6warp.cuh + test_fmha_6warp*.cu
2026-05-28 16:34:14 +00:00
2a6d72912a auto: pre-test commit 2026-05-28 16:28:58 +00:00
e74c84458c Clean up E2M1 dequant: use LUT approach (consultant recommendation)
Both indexer files now use a constexpr LUT matching Python's
E2M1_MAGNITUDES = [0, 0.5, 1, 1.5, 2, 3, 4, 6].
This is cleaner and more auditable than bit-manipulation.
2026-05-28 16:17:47 +00:00
79ef87f9a9 FIX: E2M1 FP4 dequantization bug in indexer_score_topk.cu
The dequant_fp4_scalar function was treating the magnitude bits as
a raw integer (0-6) instead of the E2M1 floating-point format:
  Old (WRONG): val = (int)(nibble & 0x07) * scale
  New (CORRECT): proper E2M1 decode with exponent + mantissa

E2M1 encoding (bias=1):
  exp=0 subnormal: 0b000=0, 0b001=0.5
  exp=1: 0b010=1, 0b011=1.5
  exp=2: 0b100=2, 0b101=3
  exp=3: 0b110=4, 0b111=6

Bug found by outside consultant. Affects indexer top-k selection
correctness — wrong FP4 key decoding would select wrong CSA blocks.

Fixed in both:
- dsv4/kernels/indexer/indexer_score_topk.cu
- dsv4/kernels/cuda/indexer_score_topk.cu
2026-05-28 16:16:24 +00:00
44c4bade5f Rewrite fmha_sm100_tc.cuh with working N=16 PV sub-tile approach
Production FMHA kernel template for Blackwell SM100:
- FmhaSm100Kernel<HD>::launch(q, k, v, o, s_k, scale, stream)
- QK: SS MMA N=128, one K-tile at a time
- PV: SS MMA N=16 sub-tiles (HD/16 calls per K-tile)
- Epilogue: TMEM → regs → BF16 → GMEM
- ~25KB SMEM for all HD values
- All HD=16/64/128/256 pass with cos 0.999997+
2026-05-28 16:04:11 +00:00