Commit Graph

873 Commits

Author SHA1 Message Date
bf7cf54a51 fix: align TMA SMEM to 128 bytes in verification test 2026-05-29 18:27:07 +00:00
befc2c647b test: TMA load verification — compare against direct GMEM read 2026-05-29 18:26:34 +00:00
c69f3668e1 feat: TMA async FMHA kernel — WORKING on B200
Three critical CUDA 13 fixes that made TMA work:
1. globalStrides in BYTES not elements (root cause of desc creation failures)
2. BFLOAT16 data type instead of UINT16
3. mbarrier wait: selp.b32 polling pattern (@p bra HANGS on SM100!)

Also includes CUTLASS driver workaround (bit 21 clear for drv <= 13.1).

Verified: 2D TMA load of (128,16) BF16 tile = 0 mismatches.
Kernel: fmha_6warp_tma_kernel with per-sub-tile TMA loads for Q, K, V.
Test: test_fmha_tma.cu with padded Q allocations and per-head descriptors.
2026-05-29 07:02:07 +00:00
a40c05f3f2 archive: TMA driver-API files + CUDA 13 TMA discovery notes
Key findings documented in docs/cuda13_tma_notes.md:
- CUDA 13 globalStrides are in BYTES not elements (root cause of desc creation failures)
- BFLOAT16 data type available in CUDA 13
- Driver API descriptors create OK but cp.async.bulk.tensor hangs on driver 13.0 + toolkit 13.2
- CuTeDSL tma_partition works (production path)

Archived (not deleted):
- fmha_tma_driver_api.cuh, fmha_6warp_tma_driver_api.cuh, test_fmha_tma_driver_api.cu
- These will work once driver matches toolkit version
2026-05-29 06:52:39 +00:00
85cd95e609 debug: TMA context fix test 2026-05-29 04:45:54 +00:00
76c82ebdcd debug: detailed TMA descriptor debug test 2026-05-29 04:45:06 +00:00
0c9245b4d2 fix: add cuInit(0) for CUDA driver API 2026-05-29 04:43:24 +00:00
6cc2f61431 debug: TMA descriptor dimension test 2026-05-29 04:42:44 +00:00
409838ace2 refactor: per-sub-tile TMA loads with padded GMEM allocations
- Q, K, V all loaded per (128,16) sub-tile via TMA
- Q GMEM padded to (128, HD) to satisfy TMA tile requirements
- Simpler SMEM layout — only (128,16) staging buffers needed
- Updated test with padded allocations
2026-05-29 04:41:03 +00:00
8c17f65f5b fix: cast typo 2026-05-29 04:39:21 +00:00
8908b697dd fix: bool type mismatch 2026-05-29 04:39:12 +00:00
b78ebe8a9c debug: add TMA descriptor error reporting 2026-05-29 04:38:57 +00:00
696462f07a feat: TMA async load infrastructure for FMHA kernel
- fmha_tma.cuh: TMA descriptor creation, mbarrier helpers, cp.async.bulk.tensor.2d wrappers
- fmha_6warp_tma.cuh: TMA-integrated multirow kernel with async GMEM→SMEM loads
  - TMA loads Q, K, V tiles to row-major SMEM
  - Transposes to canonical K-major layout for MMA
  - Same softmax/epilogue as non-TMA kernel
- test_fmha_tma.cu: Test harness for TMA FMHA (HD=64 first)
2026-05-29 04:36:52 +00:00
d1c1eaeddc clean: remove debug prints, multirow kernel complete with multi-tile KV merge 2026-05-28 23:57:31 +00:00
c65baabcc9 fix: V tile copy — V is (HD, SK_TOTAL) so tile columns are not contiguous 2026-05-28 23:55:52 +00:00
869460a932 debug: add LSE verification and merge debug prints 2026-05-28 23:54:30 +00:00
2f2259395e fix: always normalize in kernel, correct KV merge with normalized O + LSE 2026-05-28 23:53:44 +00:00
914f76d30c multirow: add normalize flag, un-norm + LSE output, multi-tile KV merge test 2026-05-28 23:51:23 +00:00
ca5cf0e517 test: add multi-head and batched prefill tests for multirow kernel 2026-05-28 23:48:53 +00:00
0220e51d18 fix: typo cudaErrorCudaSuccess -> cudaSuccess 2026-05-28 23:18:21 +00:00
468614a4e2 fmha_multirow: non-interleaved design — softmax first, then PV
KEY FIX: TMEM is shared between QK output (S) and PV output (O).
Cannot interleave softmax reads with PV writes because PV overwrites S.

New flow:
1. QK GEMM → S in TMEM
2. Softmax: read ALL S from TMEM, compute P in registers
   - Pass 1: row_max (4 warps, 32x32b.x8)
   - Pass 2: exp, sum, store P in p_vals[SK_TILE] registers
3. PV GEMM: write P to sPk per K-tile, accumulate O in TMEM
4. Epilogue: read O from TMEM, normalize, write GMEM

P in registers: each lane holds float p_vals[128] = 512 bytes.
Register budget: 128 lanes × 512B = 64KB (within B200 256KB register file).
2026-05-28 23:17:43 +00:00
c768abed95 test: softmax-only kernel (QK + row_max, no PV) 2026-05-28 23:15:36 +00:00
d840fbbf85 test: clean multirow test with proper SMEM calc 2026-05-28 23:10:49 +00:00
f2124b9378 fix: SMEM calc in decode test 2026-05-28 23:08:54 +00:00
58ff781388 test: simplified decode kernel for debugging multirow 2026-05-28 23:08:33 +00:00
ff8c677486 fix: SMEM size for MMA test — account for both sQ0 and sK0 2026-05-28 23:06:07 +00:00
fee022a485 test: MMA→4-warp read using proven fmha_common+umma_desc infra 2026-05-28 23:05:29 +00:00
e1a708a187 test: try 16x256b.x1 with column step=4 (4 cols per read) 2026-05-28 23:03:51 +00:00
95003eced2 test: 16x256b.x1 loads with uint32_t regs, matching working pattern 2026-05-28 23:03:10 +00:00
fffb493b0e fix: 16x256b.x1 load syntax — single address operand 2026-05-28 23:02:23 +00:00
44dcd6e8d0 test: 16x256b.x1 multiple LOADS — do they crash like stores? 2026-05-28 23:02:03 +00:00
d54bce6a6d fix: correct SMEM size for MMA 4-warp test 2026-05-28 23:01:12 +00:00
be45e87891 test: MMA→4-warp TMEM read — do warps see different rows? 2026-05-28 23:00:27 +00:00
6b0d57074a test: TMEM cross-warp visibility with different sync strategies 2026-05-28 22:59:31 +00:00
77d190278e test: simpler TMEM 4-warp read — direct store+load 2026-05-28 22:58:48 +00:00
91b03bd6bd test: verify 4-warp TMEM read with 32x32b.x8 after MMA 2026-05-28 22:57:59 +00:00
28e04a5ea8 fix: use __cvta_generic_to_shared directly for 64-bit compat 2026-05-28 22:56:29 +00:00
1d6a95df32 fix: typo in tmem row offset test 2026-05-28 22:56:15 +00:00
cf6fe71368 test: verify TMEM 32x32b.x8 row offset addressing 2026-05-28 22:56:00 +00:00
4cfb707405 fix: correct SMEM size calculation in multirow test 2026-05-28 22:53:46 +00:00
863a030c3b fmha_multirow: rewrite with 32x32b.x8 only, no s_p_vals, row_page addressing
- Kill 64KB s_p_vals buffer — P is streamed per K-tile through sPk
- All TMEM ops use 32x32b.x8 exclusively (16x256b.x1 crashes on 2nd call)
- T>32: 4 softmax warps use row_page offset in TMEM address (row<<16)
- Lane l in warp w handles row w*32+l
- Two-pass softmax: pass 1 row_max, pass 2 exp/sum interleaved with PV
- PV: N=16 sub-tiles, SS MMA sPk(128,16) × sV(16,16) → TMEM
- Epilogue: 32x32b.x8 TMEM read, normalize, BF16 → GMEM
- SMEM budget: ~14KB (well within 232KB)
2026-05-28 22:52:52 +00:00
08694b8136 Fix multi-row softmax v3: 32x32b.x8 with per-lane per-row (no wmax/wsum), per-row sRowMax/sRowSum arrays 2026-05-28 20:10:13 +00:00
bf4dfd131b Fix nvcc goto-bypasses-init: move var decls before goto targets 2026-05-28 20:04:59 +00:00
2b09d4f2ef Fix nvcc goto-bypasses-init in multi-row test 2026-05-28 20:04:45 +00:00
d8b421ccee Multi-row FMHA kernel (Milestone 4): T>1 prefill support with 4-warp parallel softmax 2026-05-28 20:04:29 +00:00
3fd302e7a0 Fix nvcc goto-bypasses-init errors in multi-head test 2026-05-28 19:33:04 +00:00
aa41cfa2e5 Multi-head FMHA kernel (Milestone 5): grid launch with MHA/MQA/batch support
- fmha_6warp_multihead.cuh: grid=(1, n_h, batch) kernel with FmhaParams
- MQA support via k_head_stride=0 / v_head_stride=0
- LSE output for multi-segment KV merge composition
- test_fmha_6warp_multihead.cu: MHA (4+8 heads), MQA, batched tests
- HD-specific wrappers for hd=16/64/128/256
- Marked E2M1 dequant bug as FIXED in consultant issue file
2026-05-28 19:32:35 +00:00
6af2feb42a TMA 5D test: element stride decomposition 2026-05-28 19:18:01 +00:00
96f2f0bb90 auto: pre-test commit 2026-05-28 19:12:23 +00:00
015435b1ab auto: pre-test commit 2026-05-28 19:09:50 +00:00