Commit Graph

1930 Commits

Author SHA1 Message Date
570396b4be debug: simplify test, add fflush 2026-05-30 04:42:35 +00:00
0ad35f8be6 debug: add prints to multirow multitile test 2026-05-30 04:40:06 +00:00
dd3e0fdfc8 D1.5: multi-row + multi-tile FMHA with SMEM accumulator in-kernel rescale 2026-05-30 04:37:33 +00:00
10ae8f3346 auto: pre-test commit 2026-05-30 03:46:38 +00:00
8b1ac380ac feat: HD=512 support — TMEM_N=512, test variants for all three TMA kernels 2026-05-30 03:45:05 +00:00
762f054d6d feat: double-buffer TMA pipeline in multi-row kernel 2026-05-30 03:20:49 +00:00
4a9c850e9c feat: double-buffer TMA pipeline for K loads in single-tile kernel 2026-05-30 03:14:06 +00:00
afa949071b fix: brace structure in V TMA conversion 2026-05-29 22:59:18 +00:00
ec577f71ee feat: V TMA loads in single-tile kernel too 2026-05-29 22:57:59 +00:00
422e7bb312 cleanup: v_head reference in multi-row (V via TMA now) 2026-05-29 22:54:44 +00:00
88c72a887e feat: V TMA loads in multi-row kernel 2026-05-29 22:51:24 +00:00
13403d2808 cleanup: remove unused v_head in multi-tile (V via TMA) 2026-05-29 22:48:50 +00:00
74145a31cc feat: V TMA loads in multi-tile kernel 2026-05-29 22:46:21 +00:00
680d2ebf64 test: V TMA diagnostic — isolate V TMA descriptor issue 2026-05-29 22:42:46 +00:00
077fbdf3c5 test: HD=128/256 multi-tile variants 2026-05-29 20:02:00 +00:00
7df17384fd test: multi-tile s_k=128/256/384/512 2026-05-29 19:59:21 +00:00
d47b2bfcce fix: use un-normalized P for multi-tile PV (correct online softmax merge) 2026-05-29 19:57:54 +00:00
43ae3e7f98 fix: reload Q per-K-sub-tile in multi-tile kernel (same as single-tile) 2026-05-29 19:56:35 +00:00
7598d548ee debug: test multi-tile with s_k=128 only 2026-05-29 19:53:02 +00:00
8e99bd50e6 feat: 6-warp TMA multi-tile KV kernel with register accumulator + test 2026-05-29 19:49:53 +00:00
1814510195 wip: add n_kv_tiles param for multi-tile KV (not yet used) 2026-05-29 19:47:48 +00:00
d20792aa9d fix: TMA descriptor index for batched multi-head (batch*n_h + head) 2026-05-29 19:45:44 +00:00
754c6a692c feat: per-head TMA descriptors for multi-head FMHA 2026-05-29 19:44:58 +00:00
9eb193458e test: refactored multi-row TMA test with multi-head and batch 2026-05-29 19:43:41 +00:00
832a04181d test: relax relative error threshold to 5% for BF16, use cosine > 0.999 as pass criterion 2026-05-29 19:41:40 +00:00
bfef94f5d0 test: HD=128/256 multi-row TMA FMHA 2026-05-29 19:40:32 +00:00
a1b2ab79a1 feat: 6-warp TMA FMHA multi-row kernel + test 2026-05-29 19:39:17 +00:00
d0a50f1f2e fix: remove double normalization in TMA epilogue (P already normalized before PV) 2026-05-29 19:36:41 +00:00
fb971781aa fix: revert V to direct load (V TMA needs debugging), K TMA works 2026-05-29 19:35:44 +00:00
cd2c028b39 feat: TMA loads for both K and V in 6-warp FMHA kernel 2026-05-29 19:34:48 +00:00
523d3838a2 test: HD=128/256 variants for TMA FMHA 2026-05-29 19:32:49 +00:00
bd4f09d514 fix: ambiguous MMA_K_BF16 in test 2026-05-29 19:32:15 +00:00
4459ddefdd feat: 6-warp TMA FMHA kernel + test — TMA for K loads 2026-05-29 19:32:02 +00:00
7a8ba8eeb6 fix: SMEM size calculation — TILE_SZ is in BF16 elements, need *sizeof(bf16_t) for bytes 2026-05-29 19:30:50 +00:00
aac1b25442 test: TMA QK diagnostic — 3 variants to isolate failure 2026-05-29 19:29:35 +00:00
9dfada6626 test: TMA + canonical + QK GEMM incremental 2026-05-29 19:28:23 +00:00
0435e229bd fix: typo cuda_SUCCESS -> cudaSuccess 2026-05-29 19:27:30 +00:00
74514e2680 test: TMA sub-tile load — exact pattern from test_qk_softmax 2026-05-29 19:26:56 +00:00
e449d6d5e1 test: TMA diagnostic with 192 threads 2026-05-29 19:26:09 +00:00
0b36b6047a test: TMA diagnostic with 128 threads 2026-05-29 19:25:38 +00:00
a766b488c2 test: minimal TMA diagnostic — isolate multi-warp TMA bug 2026-05-29 19:25:01 +00:00
fe3b6b8d13 test: QK+softmax T=1 first 2026-05-29 19:12:26 +00:00
a9a87fe7b8 fix: P write with lane stride, use sRowSum 2026-05-29 19:11:19 +00:00
fd6a9b00ae test: QK + softmax — verify P values against reference 2026-05-29 19:10:08 +00:00
5eff53c145 fix: SMEM layout and printf in PV-only test 2026-05-29 19:08:39 +00:00
106f103c83 test: PV-only GEMM — isolate PV from full FMHA pipeline 2026-05-29 19:06:52 +00:00
5542a9da00 debug: V loaded directly from GMEM (not TMA) to isolate PV issue 2026-05-29 18:57:42 +00:00
2262e10fca fix: PV GEMM — V canonical uses CORES_MN_V=2 (block_mn=16), not 16
V is the B operand with block_mn=16 in the PV MMA. Its canonical layout
uses CORES_MN=16/8=2, not 128/8=16. The previous code used CORES_MN=16
which produced wrong canonical indexing → garbage PV output.

Also:
- V SMEM size is (16,16) canonical = 256 BF16, not (128,16) = 2048
- P written as 16 elements at row 0 (T=1 decode)
- V loaded from TMA (16,128) and sub-sampled to (16,16) canonical
- V TMA coord: {col_start, d_base} for (HD,s_k) tensor
2026-05-29 18:54:02 +00:00
90c3372040 refactor: TMA FMHA kernel — 4-warp, proven pattern, full pipeline
Complete rewrite of fmha_6warp_tma.cuh based on lessons learned:
- 128 threads (4 warps) instead of 192 (6 warps) — simpler, proven
- Warp 0: TMA load + softmax, Warp 1: MMA + TMEM alloc
- TMA: mbarrier.arrive.expect_tx (root cause fix), phase parity tracking
- Q loaded directly (T=1 decode), K/V via TMA
- Per-K-sub-tile Q and K loading into (128,16) canonical buffers
- Full softmax + PV GEMM + epilogue pipeline
- Test updated to match new kernel signature
2026-05-29 18:50:58 +00:00
d5e20b2d42 fix: reference should be raw dot product (MMA is unscaled) 2026-05-29 18:48:39 +00:00