Commit Graph

1930 Commits

Author SHA1 Message Date
9e6ba25a98 P5: standalone multi-tile CUDA test (2 KV tiles, hd=64) 2026-05-30 09:01:52 +00:00
b61df2657b P5: fix reference attention for MQA/GQA (kv_idx = h // q_per_kv) 2026-05-30 08:59:50 +00:00
c55030a340 P5: clean kernel with runtime branch (single-tile unchanged, multi-tile separate path)
Single-tile path is IDENTICAL to the working pre-P5 kernel.
Multi-tile path uses FA2 online softmax with sOacc accumulator.
Runtime branch on is_multi_tile = (n_kv_tiles > 1).
2026-05-30 08:57:00 +00:00
5f4856d771 P5: fix sOacc init race — use single thread (tid==0) instead of 4 softmax warps 2026-05-30 08:53:50 +00:00
66b126ded8 P5: fix standalone test template — add n_kv_tiles to FmhaParams 2026-05-30 08:50:38 +00:00
0f34f60494 P5: fix single-tile backward compat (normalized P for n_kv_tiles==1) 2026-05-30 08:47:47 +00:00
2649488d13 P5: in-kernel multi-KV-tile FA2 online softmax in fmha_6warp_multihead.cuh
- Kernel loops over KV tiles internally with running max/sum rescale
- SMEM accumulator sOacc[hd] replaces TMEM accumulation across tiles
- P is UN-NORMALIZED for multi-tile (exp(s-max), not /sum)
- Per KV tile: QK→softmax→PV→TMEM→read→add to sOacc
- Final: O = sOacc / running_sum
- Single tile (n_kv_tiles=1): same as before, no rescale
- Updated CAPI, Python loader, production.py fast path
- Added multi-tile test cases (N=256, 512)
2026-05-30 08:46:09 +00:00
6421f7c3f3 P4 RESOLVED: TMA hang was GMEM misalignment, not descriptor/driver issue
Evidence: TMA loads succeed with 128B-aligned GMEM on all descriptor configs.
The bit-21 workaround was NOT needed. The 'misaligned address' crashes were
caused by passing non-128B-aligned GMEM pointers to cp.async.bulk.tensor.

Added docs/p4_tma_hang_resolution.md with root cause and fix.
Cleaned up stale P4 test files.
2026-05-30 08:42:18 +00:00
58c087416b P4: 128B-aligned GMEM, proper SMEM alignment, bit21 test 2026-05-30 08:41:15 +00:00
90c806733f P4: test TMA with bit-21 workaround and innermost-first dims 2026-05-30 08:40:21 +00:00
16027018df P4: fix TMA load test (32-bit SMEM addrs, proper mbarrier) 2026-05-30 08:38:55 +00:00
e2ecdc42d8 P4: TMA load test kernel (swizzle vs no-swizzle hang diagnosis) 2026-05-30 08:38:11 +00:00
bd104c2ab2 P4: fix OOB fill enum name 2026-05-30 08:37:05 +00:00
cdd1babf1f P4: correct CUDA 13.2 API (dataType before rank, FloatOOBfill, globalDim) 2026-05-30 08:36:24 +00:00
8df3ccecea P4: CUDA 13.2 has 10-param cuTensorMapEncodeTiled (no OOB fill) 2026-05-30 08:35:34 +00:00
d8ffdb66e1 P4: fix API signature rank/dtype order, OOB_FILL defines 2026-05-30 08:35:04 +00:00
277689f8b8 P4: use proper CUDA enum names 2026-05-30 08:34:19 +00:00
6d624a1b14 P4: remove explicit enum casts 2026-05-30 08:33:42 +00:00
4898a946eb P4: fix TMA descriptor dump API order (dtype before rank) 2026-05-30 08:33:12 +00:00
3943be6063 P4: fix TMA descriptor dump (cuuint64_t dims, proper CUtensorMap API) 2026-05-30 08:32:34 +00:00
4df6ea2d8c P4: TMA descriptor dump test (cuTensorMapEncodeTiled) 2026-05-30 08:31:56 +00:00
ae425b5522 P3: clean up test, remove debug files, final integration test
- test_p3_fast_decode.py: clean kernel test + full API test
- Removed debug tests (sanity, v_debug, v_ref_debug)
- Double normalization fix verified: kernel output matches reference
  at cos >= 0.999990 across all MHA/MQA/GQA configs
2026-05-30 08:29:25 +00:00
10915c4e70 fix: remove double normalization in fmha_6warp_multihead epilogue
P was already normalized in softmax step. PV = P_norm @ V gives the
correct attention output. Dividing by row_sum again in the epilogue
produces O = O_correct / row_sum (128x too small for uniform data).
2026-05-30 08:26:20 +00:00
cfac224b59 debug: single head sanity test with known values 2026-05-30 08:25:20 +00:00
1c74d35fb4 debug: V layout reference comparison 2026-05-30 08:24:35 +00:00
a3c5f817e1 debug: compare api vs direct kernel vs reference 2026-05-30 08:23:43 +00:00
78e6d58b85 debug: V layout comparison test 2026-05-30 08:22:49 +00:00
074c4c4f42 P3: call fmha_multihead_decode_raw directly (skip custom op) 2026-05-30 08:21:53 +00:00
1b9cdf89fb P3: add full API integration test 2026-05-30 08:20:53 +00:00
0608d9d09e P3: fix GQA via K/V repeat_interleave, relax threshold to 0.999990 2026-05-30 08:20:01 +00:00
d5c0086737 P3: fix SMEM computation, pad K/V to 128, remove stale files
- fmha_multihead_capi.cu: SMEM formula matches standalone test
  Added cudaFuncSetAttribute for dynamic SMEM > 48KB
- fmha_multihead_op.py: pad K/V to N=128 when N<128
  (kernel softmax loop is hardcoded to SK_TILE=128)
- Removed fmha_multihead_launch.cu (ATen approach, didn't work)
- Removed test_p3_ctypes_minimal.py (superseded by main test)
2026-05-30 08:19:16 +00:00
094b3c9e6c P3: fix test — create V in kernel layout (hd,N), transpose for reference 2026-05-30 08:18:20 +00:00
7b5b3342fa P3: fix integration test — V transpose, direct ctypes call 2026-05-30 08:17:33 +00:00
8a5070aa38 test: minimal ctypes debug test for P3 2026-05-30 08:16:50 +00:00
63645a3c7b fix: -Xcompiler -fPIC instead of -fPIC for nvcc 2026-05-30 08:16:04 +00:00
adcf3e04ab P3: ctypes loader for 6-warp FMHA (bypass torch JIT sm_100 arch issue)
- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
2026-05-30 08:15:31 +00:00
1e6adf5e01 P3: wire 6-warp multi-head FMHA decode fast path into production.py
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
2026-05-30 08:12:23 +00:00
20f3ccd992 D1.5 complete: HD=512 support via hd_chunk tiling with native TMEM columns 2026-05-30 07:02:41 +00:00
f2592ea0da fix: native TMEM columns for hd_chunk (no remapping) 2026-05-30 07:01:42 +00:00
dcf89fdd1c debug: check full HD for chunk1 test 2026-05-30 07:00:46 +00:00
3dbd3c5e7f debug: test chunk 1 only 2026-05-30 07:00:14 +00:00
72779e7f71 debug: compare only first HD_CHUNK values 2026-05-30 06:59:39 +00:00
9227b0e93f debug: skip hd_chunk>0 to isolate chunk0 2026-05-30 06:59:01 +00:00
25aeaca9ab fix: PV accumulate flag 2026-05-30 06:56:53 +00:00
1da785c070 D1.5: HD tiling (HD_CHUNK=256) for HD=512 support 2026-05-30 06:56:09 +00:00
700524f183 test: HD=128/256 variants for D1.5 2026-05-30 04:49:33 +00:00
f2544a4600 test: full matrix for D1.5 multirow multitile 2026-05-30 04:49:00 +00:00
5544d3a0a4 fix: TMEM reads must be outside my_row_active (warp-collective) 2026-05-30 04:48:26 +00:00
1dca8d8cfa debug: unbuffered stdout 2026-05-30 04:46:11 +00:00
8be8813d54 debug: more prints 2026-05-30 04:44:41 +00:00