Commit Graph

1652 Commits

SHA1 Message Date
85cd95e609 debug: TMA context fix test 2026-05-29 04:45:54 +00:00
76c82ebdcd debug: detailed TMA descriptor debug test 2026-05-29 04:45:06 +00:00
0c9245b4d2 fix: add cuInit(0) for CUDA driver API 2026-05-29 04:43:24 +00:00
6cc2f61431 debug: TMA descriptor dimension test 2026-05-29 04:42:44 +00:00
3412ff1a9b fix: TMA tile strides must match global strides, not tile dimensions
The tile stride in the outer dimension should be the global row stride
(cols), not the tile width. The tile is a window into the global tensor
and elements are addressed with global strides.
2026-05-29 04:41:53 +00:00
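The stride fix in 3412ff1a9b can be illustrated with a host-side sketch. This is not the TMA descriptor code from the repository; `copy_tile` and its parameters are hypothetical names, and the matrix is assumed to be row-major. The point is only that stepping from one tile row to the next advances by the global row stride (`cols`), while the tile width is used solely for the packed destination buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copy a (tile_rows x tile_cols) window out of a
// row-major (rows x cols) matrix into a densely packed tile buffer.
std::vector<float> copy_tile(const std::vector<float>& global,
                             std::size_t rows, std::size_t cols,
                             std::size_t row0, std::size_t col0,
                             std::size_t tile_rows, std::size_t tile_cols) {
    assert(global.size() == rows * cols);
    assert(row0 + tile_rows <= rows && col0 + tile_cols <= cols);

    std::vector<float> tile(tile_rows * tile_cols);
    for (std::size_t r = 0; r < tile_rows; ++r) {
        for (std::size_t c = 0; c < tile_cols; ++c) {
            // Source index uses the GLOBAL row stride (cols).
            // Using tile_cols here is the bug the commit describes.
            const std::size_t src = (row0 + r) * cols + (col0 + c);
            // Destination index uses the tile width, since the tile is packed.
            const std::size_t dst = r * tile_cols + c;
            tile[dst] = global[src];
        }
    }
    return tile;
}
```

For a 4x8 matrix filled with its own linear indices, a 2x3 tile starting at row 1, column 2 picks up elements 10, 11, 12 from the first tile row and 18, 19, 20 from the second: the rows sit 8 elements apart in memory, not 3.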
409838ace2 refactor: per-sub-tile TMA loads with padded GMEM allocations
- Q, K, V all loaded per (128,16) sub-tile via TMA
- Q GMEM padded to (128, HD) to satisfy TMA tile requirements
- Simpler SMEM layout — only (128,16) staging buffers needed
- Updated test with padded allocations
2026-05-29 04:41:03 +00:00
8c17f65f5b fix: cast typo 2026-05-29 04:39:21 +00:00
8908b697dd fix: bool type mismatch 2026-05-29 04:39:12 +00:00
b78ebe8a9c debug: add TMA descriptor error reporting 2026-05-29 04:38:57 +00:00
c7a6d7d231 fix: tma_mbar_init → tma_mbarrier_init (typo) 2026-05-29 04:37:48 +00:00
696462f07a feat: TMA async load infrastructure for FMHA kernel
- fmha_tma.cuh: TMA descriptor creation, mbarrier helpers, cp.async.bulk.tensor.2d wrappers
- fmha_6warp_tma.cuh: TMA-integrated multirow kernel with async GMEM→SMEM loads
  - TMA loads Q, K, V tiles to row-major SMEM
  - Transposes to canonical K-major layout for MMA
  - Same softmax/epilogue as non-TMA kernel
- test_fmha_tma.cu: Test harness for TMA FMHA (HD=64 first)
2026-05-29 04:36:52 +00:00
d1c1eaeddc clean: remove debug prints, multirow kernel complete with multi-tile KV merge 2026-05-28 23:57:31 +00:00
c65baabcc9 fix: V tile copy — V is (HD, SK_TOTAL) so tile columns are not contiguous 2026-05-28 23:55:52 +00:00
869460a932 debug: add LSE verification and merge debug prints 2026-05-28 23:54:30 +00:00
2f2259395e fix: always normalize in kernel, correct KV merge with normalized O + LSE 2026-05-28 23:53:44 +00:00
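The merge that 2f2259395e refers to can be sketched on the host as follows. This is a reference formulation of the standard log-sum-exp merge, not the kernel code; `Partial` and `merge_partials` are hypothetical names. It assumes each KV tile produced an already normalized output row plus that row's log-sum-exp (LSE), which is what the commit message describes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container: one normalized output row and its log-sum-exp.
struct Partial {
    std::vector<float> o;  // normalized attention output, length HD
    float lse;             // log(sum_j exp(s_j)) over this KV tile
};

// Merge two partial results computed over disjoint KV tiles.
Partial merge_partials(const Partial& a, const Partial& b) {
    assert(a.o.size() == b.o.size());

    // Subtract the larger LSE before exponentiating for numerical stability.
    const float m = std::max(a.lse, b.lse);
    const float wa = std::exp(a.lse - m);
    const float wb = std::exp(b.lse - m);
    const float denom = wa + wb;

    Partial out;
    out.o.resize(a.o.size());
    for (std::size_t d = 0; d < a.o.size(); ++d) {
        // Each normalized partial is re-weighted by its share of the total mass.
        out.o[d] = (wa * a.o[d] + wb * b.o[d]) / denom;
    }
    out.lse = m + std::log(denom);
    return out;
}
```

Because the result is itself a normalized output with an LSE, the same function can be folded over any number of KV tiles.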
914f76d30c multirow: add normalize flag, un-norm + LSE output, multi-tile KV merge test 2026-05-28 23:51:23 +00:00
ca5cf0e517 test: add multi-head and batched prefill tests for multirow kernel 2026-05-28 23:48:53 +00:00
ac8fa779e2 fix: move epilogue TMEM loads outside my_row_active guard (warp-collective hang) 2026-05-28 23:46:46 +00:00
55c0604a71 add fence.sc.gpu between PV and epilogue for TMEM visibility 2026-05-28 23:21:53 +00:00
52809b0ec6 fix: tcgen05.wait::ld.sync.aligned (was missing 'sync') 2026-05-28 23:19:03 +00:00
0220e51d18 fix: typo cudaErrorCudaSuccess -> cudaSuccess 2026-05-28 23:18:21 +00:00
468614a4e2 fmha_multirow: non-interleaved design — softmax first, then PV
KEY FIX: TMEM is shared between QK output (S) and PV output (O).
Cannot interleave softmax reads with PV writes because PV overwrites S.

New flow:
1. QK GEMM → S in TMEM
2. Softmax: read ALL S from TMEM, compute P in registers
   - Pass 1: row_max (4 warps, 32x32b.x8)
   - Pass 2: exp, sum, store P in p_vals[SK_TILE] registers
3. PV GEMM: write P to sPk per K-tile, accumulate O in TMEM
4. Epilogue: read O from TMEM, normalize, write GMEM

P in registers: each lane holds float p_vals[128] = 512 bytes.
Register budget: 128 threads (4 warps × 32 lanes) × 512 B = 64 KB (within the B200's 256 KB register file).
2026-05-28 23:17:43 +00:00
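The non-interleaved flow from 468614a4e2 can be written as a single-row host reference. This is a sketch for checking the ordering of the steps, not the kernel itself; `attend_row` is a hypothetical name, the scores `s` stand in for one row of S after the QK GEMM, and V is assumed to be stored as (HD, SK) row-major to match the layout noted in c65baabcc9.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical single-row reference: scores s[SK], V as (HD, SK) row-major.
std::vector<float> attend_row(const std::vector<float>& s,
                              const std::vector<float>& v,
                              std::size_t hd) {
    const std::size_t sk = s.size();
    assert(sk > 0 && v.size() == hd * sk);

    // Pass 1: row maximum over all scores.
    float row_max = s[0];
    for (std::size_t j = 1; j < sk; ++j) row_max = std::max(row_max, s[j]);

    // Pass 2: exponentiate and sum, keeping every P value (the p_vals array).
    std::vector<float> p(sk);
    float row_sum = 0.0f;
    for (std::size_t j = 0; j < sk; ++j) {
        p[j] = std::exp(s[j] - row_max);
        row_sum += p[j];
    }

    // PV: runs only after all of S has been consumed, so overwriting the
    // storage that held S would be safe at this point.
    std::vector<float> o(hd, 0.0f);
    for (std::size_t d = 0; d < hd; ++d)
        for (std::size_t j = 0; j < sk; ++j)
            o[d] += p[j] * v[d * sk + j];

    // Epilogue: normalize by the row sum.
    for (std::size_t d = 0; d < hd; ++d) o[d] /= row_sum;
    return o;
}
```

With two equal scores the softmax weights are both 0.5, so each output element is the mean of the corresponding V row, which makes a convenient sanity check against the kernel output.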
c768abed95 test: softmax-only kernel (QK + row_max, no PV) 2026-05-28 23:15:36 +00:00
43ba672e15 fmha_multirow: add fence.sc.gpu after QK GEMM for TMEM visibility 2026-05-28 23:13:31 +00:00
d840fbbf85 test: clean multirow test with proper SMEM calc 2026-05-28 23:10:49 +00:00
f2124b9378 fix: SMEM calc in decode test 2026-05-28 23:08:54 +00:00
58ff781388 test: simplified decode kernel for debugging multirow 2026-05-28 23:08:33 +00:00
be2685e9e3 fmha_multirow: use natural 4-warp TMEM partitioning after UMMA
After UMMA (QK GEMM), 4 warps reading TMEM with 32x32b.x8 each
see a different 32-row partition (verified on B200):
  Warp 0 → rows 0-31, Warp 1 → rows 32-63, etc.
Lane l in warp w reads row w*32 + l.

This eliminates the broken row_page<<16 addressing and allows:
- T<=32: warp 0 only, 32x32b.x8, each lane = one row
- T>32: 4 warps, each reads its natural 32-row partition
- Epilogue: same partitioning for reading O from TMEM

No s_p_vals buffer. P streamed per K-tile through sPk.
2026-05-28 23:07:31 +00:00
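The row ownership described in be2685e9e3 reduces to a small index mapping. The helpers below are hypothetical names written for illustration; they encode only what the commit message states (lane l in warp w reads row w*32 + l, warp 0 alone when T <= 32, four warps otherwise) and say nothing about the TMEM address encoding itself.

```cpp
#include <cassert>

// Hypothetical pair identifying a thread within the 4 softmax warps.
struct WarpLane {
    unsigned warp;
    unsigned lane;
};

// Row read by a given lane of a given warp: w*32 + l.
constexpr unsigned row_of(unsigned warp, unsigned lane) {
    return warp * 32u + lane;
}

// Inverse mapping: which warp and lane own a given row (rows 0..127).
constexpr WarpLane owner_of(unsigned row) {
    return WarpLane{row / 32u, row % 32u};
}

// Number of warps that take part for T query rows, per the commit message.
constexpr unsigned active_warps(unsigned t) {
    return t <= 32u ? 1u : 4u;
}

static_assert(row_of(0, 0) == 0, "warp 0 starts at row 0");
static_assert(row_of(3, 31) == 127, "warp 3 ends at row 127");
```

Keeping the mapping in one place like this makes it harder for the softmax passes and the epilogue to disagree about which lane owns which row.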
ff8c677486 fix: SMEM size for MMA test — account for both sQ0 and sK0 2026-05-28 23:06:07 +00:00
fee022a485 test: MMA→4-warp read using proven fmha_common+umma_desc infra 2026-05-28 23:05:29 +00:00
e1a708a187 test: try 16x256b.x1 with column step=4 (4 cols per read) 2026-05-28 23:03:51 +00:00
95003eced2 test: 16x256b.x1 loads with uint32_t regs, matching working pattern 2026-05-28 23:03:10 +00:00
fffb493b0e fix: 16x256b.x1 load syntax — single address operand 2026-05-28 23:02:23 +00:00
44dcd6e8d0 test: 16x256b.x1 multiple LOADS — do they crash like stores? 2026-05-28 23:02:03 +00:00
d54bce6a6d fix: correct SMEM size for MMA 4-warp test 2026-05-28 23:01:12 +00:00
be45e87891 test: MMA→4-warp TMEM read — do warps see different rows? 2026-05-28 23:00:27 +00:00
6b0d57074a test: TMEM cross-warp visibility with different sync strategies 2026-05-28 22:59:31 +00:00
77d190278e test: simpler TMEM 4-warp read — direct store+load 2026-05-28 22:58:48 +00:00
91b03bd6bd test: verify 4-warp TMEM read with 32x32b.x8 after MMA 2026-05-28 22:57:59 +00:00
28e04a5ea8 fix: use __cvta_generic_to_shared directly for 64-bit compat 2026-05-28 22:56:29 +00:00
1d6a95df32 fix: typo in tmem row offset test 2026-05-28 22:56:15 +00:00
cf6fe71368 test: verify TMEM 32x32b.x8 row offset addressing 2026-05-28 22:56:00 +00:00
4cfb707405 fix: correct SMEM size calculation in multirow test 2026-05-28 22:53:46 +00:00
863a030c3b fmha_multirow: rewrite with 32x32b.x8 only, no s_p_vals, row_page addressing
- Kill 64KB s_p_vals buffer — P is streamed per K-tile through sPk
- All TMEM ops use 32x32b.x8 exclusively (16x256b.x1 crashes on 2nd call)
- T>32: 4 softmax warps use row_page offset in TMEM address (row<<16)
- Lane l in warp w handles row w*32+l
- Two-pass softmax: pass 1 row_max, pass 2 exp/sum interleaved with PV
- PV: N=16 sub-tiles, SS MMA sPk(128,16) × sV(16,16) → TMEM
- Epilogue: 32x32b.x8 TMEM read, normalize, BF16 → GMEM
- SMEM budget: ~14KB (well within 232KB)
2026-05-28 22:52:52 +00:00
1ba304db3e stuff 2026-05-28 21:08:13 +00:00
deaa3ec725 CRITICAL FIX: Q/K SMEM canonical layout must use local d (0..15) not full_d — UMMA descriptor reads from sQ0/sK0 start, not offset 2026-05-28 20:13:52 +00:00
08694b8136 Fix multi-row softmax v3: 32x32b.x8 with per-lane per-row (no wmax/wsum), per-row sRowMax/sRowSum arrays 2026-05-28 20:10:13 +00:00
aaa76c1af1 Rewrite multi-row softmax using 16x256b.x1 TMEM reads for proper multi-row access 2026-05-28 20:08:30 +00:00
5e3c61184c Fix multi-row softmax: remove cross-lane wmax/wsum — each lane handles its own row independently 2026-05-28 20:06:16 +00:00
bf4dfd131b Fix nvcc goto-bypasses-init: move var decls before goto targets 2026-05-28 20:04:59 +00:00