Commit Graph

491 Commits

Author SHA1 Message Date
7d9e70c5d5 Fix remaining mHC API references: layer_compare.py, layer.py comment 2026-05-31 18:38:34 +00:00
7b123d159f CRITICAL FIX: mHC fn/base/scale ordering [pre,post,comb] + comb transposed + Sinkhorn softmax
Bugs fixed (verified against HuggingFace DeepseekV4HyperConnection):
1. fn/base/scale ordering was [pre,comb,post], should be [pre,post,comb]
   - Was applying Sinkhorn to post values and 2*sigmoid to comb values
   - This caused residual to grow unbounded (no doubly-stochastic constraint)
2. comb (B_l) must be TRANSPOSED in post_block
   - HF: comb.transpose(-1,-2) @ hidden_streams
   - Was using B_l @ X_l without transpose
3. Sinkhorn must start from softmax(logits) + eps, not exp(logits)
   - HF: softmax → col norm → (iters-1) alternating row/col normalizations
   - Was using exp → alternating normalizations (different convergence behavior)
4. Missing hc_eps on pre (A_l)
   - HF: sigmoid(...) + hc_eps
   - Was missing the eps guard
5. Renamed W_res→W_comb, S_res→S_comb, alpha_res→alpha_comb throughout
   - Matches checkpoint naming and HF model
6. Fixed fallback mHC initialization to use new API
2026-05-31 18:38:12 +00:00
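
Item 3 of the commit above (Sinkhorn starting from softmax + eps, then a column normalization, then iters-1 alternating passes) can be sketched in plain Python as follows. This is only an illustration under assumptions: the function names, the default `iters` and `eps`, the row-wise softmax axis, and the choice to also add `eps` in the denominators are hypothetical and are not taken from the repository or from the HuggingFace implementation.

```python
import math

def softmax_rows(logits):
    """Row-wise softmax with max subtraction for numerical stability."""
    out = []
    for row in logits:
        row_max = max(row)
        exps = [math.exp(x - row_max) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def normalize_cols(mat, eps):
    """Divide each column by its sum (eps guards against division by zero)."""
    n_rows, n_cols = len(mat), len(mat[0])
    col_sums = [sum(mat[i][j] for i in range(n_rows)) + eps for j in range(n_cols)]
    return [[mat[i][j] / col_sums[j] for j in range(n_cols)] for i in range(n_rows)]

def normalize_rows(mat, eps):
    """Divide each row by its sum (eps guards against division by zero)."""
    return [[x / (sum(row) + eps) for x in row] for row in mat]

def sinkhorn_from_softmax(logits, iters=20, eps=1e-6):
    """softmax(logits) + eps, one column pass, then (iters - 1) row/col passes."""
    mat = [[p + eps for p in row] for row in softmax_rows(logits)]
    mat = normalize_cols(mat, eps)
    for _ in range(iters - 1):
        mat = normalize_rows(mat, eps)
        mat = normalize_cols(mat, eps)
    return mat

logits = [
    [0.1, -0.3, 0.5, 0.0],
    [1.0, 0.2, -0.7, 0.3],
    [-0.4, 0.6, 0.2, -0.1],
    [0.3, 0.0, -0.2, 0.8],
]
comb = sinkhorn_from_softmax(logits)
row_sums = [sum(row) for row in comb]
col_sums = [sum(comb[i][j] for i in range(4)) for j in range(4)]
```

With this starting point the rows are already normalized before the first column pass, so `row_sums` and `col_sums` both end up close to 1, which is the doubly-stochastic constraint that item 1 says was lost.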
1c18c16c68 Fix production rope.py: FP32 arithmetic for forward_rope_partial + inverse_rope_bf16 2026-05-31 09:17:36 +00:00
df6220abaf E5: Fold batch loop into native kernel grid (blockIdx.z)
The 6-warp multi-tile kernel already supports batch natively via
dim3 grid(1, n_h, batch). Removed Python for-loop for 4D input.
Single kernel launch per layer for batched decode instead of
batch_size launches.

T>1 prefill still uses per-batch dispatch (E8 future work).
2026-05-30 21:21:02 +00:00
9d88769f5f Wire indexer compute_index_scores_topk + fix compressor imports
- indexer/__init__.py: compute_index_scores_topk now calls
  run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail
  and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
2026-05-30 21:19:06 +00:00
daf84524ac E2/E3: compressor bridge, indexer bridge, flush pipeline wiring
- compress_tail.py: PyTorch reference CSA/HCA compression
  (token-level softmax over m/m' entries, paper eq. 11-12)
- compressor/__init__.py: csa_compress_and_store, hca_compress_and_store
  bridges (compression deferred to flush pipeline)
- indexer/__init__.py: compute_index_scores_topk bridge (NotImplemented)
- Fixed attention.py: removed extra positions arg to write_swa
2026-05-30 21:16:54 +00:00
d3b772196d E3: Implement DSV4Model — full model class
- Token embedding → N×TransformerLayer → RMSNorm → lm_head
- decode_step: single token decode with mHC state management
- forward: prefill path (T tokens)
- Cache handle acquisition per layer
- mHC state initialization from embedding
- Weight loading TODO (deferred to loader/)
2026-05-30 21:15:57 +00:00
b0cdd5af74 fix: extern declarations for gather_swa functions in gather_kv.cu 2026-05-30 21:14:15 +00:00
016d722abc fix: single PYBIND11_MODULE for combined gather .so
Both gather_kv.cu and gather_swa.cu are compiled into one .so.
Only gather_kv.cu defines the PYBIND11_MODULE; gather_swa.cu
just provides the function implementations.
2026-05-30 21:13:24 +00:00
8fb9d89658 fix: correct gather.py kernel_dir path 2026-05-30 21:12:09 +00:00
300dddedc0 E1-E4: gather kernels, handle wiring, rope, sync removal, e2e test
E1: LayerCacheHandle now exposes gather_compressed_kv,
    gather_all_compressed_kv, gather_swa_kv, num_query_heads, head_dim.
    Gather kernels in dsv4/kernels/cuda/gather_swa.cu + gather_kv.cu.
    Python wrapper in dsv4/kernels/cache/gather.py.

E2: tests/e2e/test_one_layer.py — SWA path smoke test.

E3: Compressor/indexer __init__.py bridges (NotImplementedError stubs
    for CSA/HCA compress_and_store, compute_index_scores_topk).

E4: Removed torch.cuda.synchronize() from fmha_multitile_op.py fast path.
    Error checking via C API return code instead.

Also: forward_rope_partial in ops/rope.py (GPT-J interleaved, last 64 dims).
2026-05-30 21:10:26 +00:00
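
The "GPT-J interleaved, last 64 dims" note in the commit above can be illustrated with a small plain-Python sketch. It assumes the common RoPE frequency schedule with base 10000 and rotates adjacent pairs (2i, 2i+1) of the last `rope_dim` entries; the function name, the `base` value, and the `inverse` flag are hypothetical and may differ from what `ops/rope.py` actually does.

```python
import math

def rope_partial_interleaved(x, pos, rope_dim=64, base=10000.0, inverse=False):
    """Rotate only the last rope_dim entries of x, pairing (2i, 2i+1)."""
    start = len(x) - rope_dim          # leading dims pass through unchanged
    out = list(x)
    sign = -1.0 if inverse else 1.0    # the inverse applies the negative angle
    for i in range(rope_dim // 2):
        theta = sign * pos * base ** (-2.0 * i / rope_dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[start + 2 * i], x[start + 2 * i + 1]
        out[start + 2 * i] = a * c - b * s
        out[start + 2 * i + 1] = a * s + b * c
    return out

x = [0.01 * i for i in range(128)]     # one head vector, head_dim = 128
y = rope_partial_interleaved(x, pos=7)
x_back = rope_partial_interleaved(y, pos=7, inverse=True)
```

Because each pair undergoes a pure rotation, the vector norm is preserved and applying the inverse at the same position recovers the input up to floating-point rounding.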
faf92b30ad E1: Wire LayerCacheHandle gather methods + CUDA gather kernels
- gather_compressed_kv: CSA top-k gather via existing gather_kv.cu
- gather_all_compressed_kv: HCA dense gather via new gather_all_compressed_kernel
- gather_swa_kv: SWA ring buffer gather via new gather_swa_kernel
- Added gather_swa.cu with both SWA + all-compressed gather kernels
- Added gather.py Python wrapper (torch.utils.cpp_extension JIT)
- Updated handle.py: added schema field, num_query_heads/head_dim properties
- Updated manager.py: passes schema + num_query_heads to handle

All gather kernels: FP8→BF16 dequant + BF16 RoPE concat in single launch.
Output: dense BF16 tensors ready for FMHA consumption.
2026-05-30 21:09:21 +00:00
4b9eed02e1 Cleanup C1-C7: delete dead CuTeDSL FMHA, test probes, scratch files
- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*,
  test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
2026-05-30 21:08:12 +00:00
95725f1df0 P8: Delete 6 redundant .cuh variants + multihead CAPI/op
Kept: fmha_6warp_tma_multirow_multitile.cuh (production kernel)
Deleted: fmha_6warp.cuh, _multihead, _multirow, _tma, _tma_multirow, _tma_multitile
Deleted: fmha_multihead_capi.cu, fmha_multihead_op.py

production.py: Removed _dsv4_attention_fast_decode, unified dispatch to
_dsv4_attention_multitile for all fast-path cases.
2026-05-30 17:21:15 +00:00
9d483b1c54 P8: Unified dispatch — multi-tile kernel handles all N
production.py: Single fast path using multi-tile kernel for all N.
Eliminates the separate _dsv4_attention_fast_decode path.
2026-05-30 17:19:09 +00:00
c0379a0f86 P6: Remove broken TMA store — use direct GMEM write from SMEM
cp.async.bulk.tensor store (SMEM→GMEM) is NOT available on SM100.
The CUTLASS SM100 epilogue uses st.global directly.

The one-way epilogue pipeline is now:
  1. TMEM → regs (tcgen05.ld, warp-collective)
  2. epilogue_op in regs (normalize, FP4 hook via ENABLE_FP4_EPILOGUE)
  3. regs → SMEM (row-major, sO_epi)
  4. SMEM → GMEM (direct write)

This is the same pattern as the MoE kernel but with st.global instead
of TMA store. Multi-CTA (D2) will use st.global with flat_divide coords.

Removed: tma_o from FmhaParams, fmha_multihead_decode_tma_launch,
sMbarStore from SMEM, broken TMA store PTX from fmha_tma.cuh.
2026-05-30 17:11:17 +00:00
f97359fbfc P6: TMA store uses mbarrier completion (same as load)
TMA store: cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes
Uses mbarrier for completion, not bulk_group. Restored sMbarStore to SMEM.
2026-05-30 17:07:24 +00:00
2de300e281 P6: Try shared::cluster instead of shared::cta for TMA store 2026-05-30 17:05:27 +00:00
829a5f93ce P6: Fix TMA store PTX — remove .tile modifier, fix wait_group syntax 2026-05-30 17:04:38 +00:00
fd7c0cb773 P6: Fix TMA store — use bulk_group (commit+wait) not mbarrier
TMA store uses cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group
NOT mbarrier::complete_tx::bytes. Completion tracked via:
  - cp.async.bulk.commit_group (after issuing stores)
  - cp.async.bulk.wait_group.read 0 (wait for all groups)

Removed sMbarStore from SMEM allocations (no longer needed).
2026-05-30 16:57:35 +00:00
212fc85627 P6: One-way TMEM→regs→SMEM→TMA store epilogue
- fmha_6warp_multihead.cuh: Rewritten epilogue with proper Blackwell pipeline
  1. TMEM → regs (tcgen05.ld, warp-collective)
  2. epilogue_op in regs (normalize, FP4 hook via ENABLE_FP4_EPILOGUE)
  3. regs → SMEM row-major (sO_epi, for TMA tile format)
  4. TMA store SMEM → GMEM (async, enables multi-CTA)
  Fallback to direct GMEM write when tma_o is nullptr.
  Added FmhaParams.tma_o field and ENABLE_FP4_EPILOGUE template param.

- fmha_6warp_tma_multirow_multitile.cuh: Same epilogue pattern for multi-tile.
  Writes normalized output to sO_epi_rowmajor + TMA store (or direct GMEM).
  Added tma_o to FmhaTmaMultiRowMultiTileParams.

- fmha_tma.cuh: Added tma_store_2d and tma_store_wait for async GMEM writes.

- fmha_multihead_capi.cu: Added fmha_multihead_decode_tma_launch with
  per-(head,batch) TMA descriptors. Updated SMEM size calculation for sO_epi + sMbarStore.

- fmha_multitile_capi.cu: Added tma_o=nullptr (backward compatible), updated SMEM size.
2026-05-30 16:56:07 +00:00
897a70a491 P5: minimal Python multi-tile test 2026-05-30 10:43:26 +00:00
a2627359fb P5: fix TMA desc creation — write to HOST then cudaMemcpy to device 2026-05-30 10:40:01 +00:00
f370bfb1f1 P5: re-enable multi-tile Python tests, fix CAPI to use create_tma_desc_2d_bf16 2026-05-30 10:38:33 +00:00
97531a68e6 fix: remove n_kv_tiles from capi too 2026-05-30 10:30:40 +00:00
f032800eaa P5: integrate WORKING multi-tile kernel (fmha_6warp_tma_multirow_multitile) into production
- fmha_multitile_capi.cu: C API wrapper for TMA multi-tile kernel
  Creates TMA descriptors per (head, batch), launches kernel
- fmha_multitile_op.py: nvcc precompile + ctypes loader
- production.py: dispatch to multitile for N>128 or hd=512
- Reverted fmha_6warp_multihead.cuh to working single-tile version
- The TMA multi-tile kernel already passes 72 configs (D1.5)
  HD=64/128/256/512 × T=1/4/32/128 × s_k=128/256/384/512
2026-05-30 10:27:38 +00:00
c55030a340 P5: clean kernel with runtime branch (single-tile unchanged, multi-tile separate path)
Single-tile path is IDENTICAL to the working pre-P5 kernel.
Multi-tile path uses FA2 online softmax with sOacc accumulator.
Runtime branch on is_multi_tile = (n_kv_tiles > 1).
2026-05-30 08:57:00 +00:00
5f4856d771 P5: fix sOacc init race — use single thread (tid==0) instead of 4 softmax warps 2026-05-30 08:53:50 +00:00
0f34f60494 P5: fix single-tile backward compat (normalized P for n_kv_tiles==1) 2026-05-30 08:47:47 +00:00
2649488d13 P5: in-kernel multi-KV-tile FA2 online softmax in fmha_6warp_multihead.cuh
- Kernel loops over KV tiles internally with running max/sum rescale
- SMEM accumulator sOacc[hd] replaces TMEM accumulation across tiles
- P is UN-NORMALIZED for multi-tile (exp(s-max), not /sum)
- Per KV tile: QK→softmax→PV→TMEM→read→add to sOacc
- Final: O = sOacc / running_sum
- Single tile (n_kv_tiles=1): same as before, no rescale
- Updated CAPI, Python loader, production.py fast path
- Added multi-tile test cases (N=256, 512)
2026-05-30 08:46:09 +00:00
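
The multi-tile scheme described in the commit above (un-normalized P per tile, running max/sum rescale, final division by the running sum) can be sketched for a single query row in plain Python. This is a reference-style illustration only: `attention_online`, `attention_reference`, and the list-based accumulator standing in for `sOacc` are hypothetical names, and nothing here models TMEM, SMEM, or warps.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_reference(q, keys, values):
    """Single-pass softmax(q . K^T) @ V for one query row."""
    scores = [dot(q, k) for k in keys]
    row_max = max(scores)
    p = [math.exp(s - row_max) for s in scores]
    total = sum(p)
    hd = len(values[0])
    return [sum(p[j] * values[j][d] for j in range(len(values))) / total
            for d in range(hd)]

def attention_online(q, keys, values, tile=128):
    """Loop over KV tiles, rescaling the accumulator by the running max."""
    hd = len(values[0])
    acc = [0.0] * hd                    # plays the role of sOacc[hd]
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [dot(q, k) for k in k_tile]
        new_max = max(running_max, max(scores))
        scale = math.exp(running_max - new_max)      # 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]  # un-normalized P
        acc = [acc[d] * scale + sum(p[j] * v_tile[j][d] for j in range(len(p)))
               for d in range(hd)]
        running_sum = running_sum * scale + sum(p)
        running_max = new_max
    return [a / running_sum for a in acc]            # O = sOacc / running_sum

q = [0.01 * (i % 7 - 3) for i in range(8)]
keys = [[math.sin(0.1 * (j + 1) * (d + 1)) for d in range(8)] for j in range(300)]
values = [[math.cos(0.05 * j + d) for d in range(4)] for j in range(300)]
o_ref = attention_reference(q, keys, values)
o_tiled = attention_online(q, keys, values, tile=128)
```

The tiled result matches the single-pass reference up to rounding, including when the last tile is partial (300 keys with a tile of 128 here).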
10915c4e70 fix: remove double normalization in fmha_6warp_multihead epilogue
P was already normalized in softmax step. PV = P_norm @ V gives the
correct attention output. Dividing by row_sum again in the epilogue
produces O = O_correct / row_sum (128x too small for uniform data).
2026-05-30 08:26:20 +00:00
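
The "128x too small for uniform data" figure in the commit above can be reproduced with a tiny plain-Python example. It assumes 128 keys with identical scores and identical values; the variable names are illustrative only.

```python
import math

n_keys = 128
scores = [0.0] * n_keys       # uniform scores
values = [1.0] * n_keys       # uniform values, one output dimension

row_max = max(scores)
exps = [math.exp(s - row_max) for s in scores]
row_sum = sum(exps)                              # 128.0 for uniform scores
p_norm = [e / row_sum for e in exps]             # already normalized in softmax

o_correct = sum(p * v for p, v in zip(p_norm, values))   # P_norm @ V
o_buggy = o_correct / row_sum                             # second division in epilogue
ratio = o_correct / o_buggy
```

Here `ratio` equals `row_sum`, so with 128 equally weighted keys the doubly normalized output is smaller by a factor of 128.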
074c4c4f42 P3: call fmha_multihead_decode_raw directly (skip custom op) 2026-05-30 08:21:53 +00:00
0608d9d09e P3: fix GQA via K/V repeat_interleave, relax threshold to 0.999990 2026-05-30 08:20:01 +00:00
d5c0086737 P3: fix SMEM computation, pad K/V to 128, remove stale files
- fmha_multihead_capi.cu: SMEM formula matches standalone test
  Added cudaFuncSetAttribute for dynamic SMEM > 48KB
- fmha_multihead_op.py: pad K/V to N=128 when N<128
  (kernel softmax loop is hardcoded to SK_TILE=128)
- Removed fmha_multihead_launch.cu (ATen approach, didn't work)
- Removed test_p3_ctypes_minimal.py (superseded by main test)
2026-05-30 08:19:16 +00:00
63645a3c7b fix: -Xcompiler -fPIC instead of -fPIC for nvcc 2026-05-30 08:16:04 +00:00
adcf3e04ab P3: ctypes loader for 6-warp FMHA (bypass torch JIT sm_100 arch issue)
- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
2026-05-30 08:15:31 +00:00
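
The "nvcc precompile + ctypes load" approach from the commit above, together with the `-Xcompiler -fPIC` fix in 63645a3c7b, might look roughly like the sketch below. The helper names `nvcc_command` and `build_and_load` and the file paths in the usage line are hypothetical; only the nvcc flags (`-shared`, `-arch`, `-Xcompiler`, `-o`) and the `ctypes.CDLL` and `subprocess.run` calls are standard. The actual loader in `fmha_multihead_op.py` may be organized differently.

```python
import ctypes
import subprocess
from pathlib import Path

def nvcc_command(source, output, arch="sm_100a"):
    """Build the nvcc argv for a shared library with a plain C API."""
    return [
        "nvcc",
        "-shared",
        f"-arch={arch}",
        "-Xcompiler", "-fPIC",   # host-compiler flag must be forwarded via -Xcompiler
        "-o", str(output),
        str(source),
    ]

def build_and_load(source, output, arch="sm_100a"):
    """Compile once if the .so is missing, then load it with ctypes."""
    output = Path(output)
    if not output.exists():
        subprocess.run(nvcc_command(source, output, arch), check=True)
    return ctypes.CDLL(str(output))

# Hypothetical paths, shown only to illustrate the resulting command line.
cmd = nvcc_command("fmha_multihead_capi.cu", "build/fmha_multihead.so")
```

Keeping the command construction separate from the build step makes it possible to inspect or log the exact nvcc invocation without a GPU toolchain installed.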
1e6adf5e01 P3: wire 6-warp multi-head FMHA decode fast path into production.py
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
2026-05-30 08:12:23 +00:00
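
The grid layout and the MQA stride trick from the Architecture notes above can be illustrated by enumerating the (head, batch) CTAs in plain Python. The function name and the `k_batch_stride` parameter are hypothetical; only `k_head_stride=0` for MQA and the `dim3(1, n_h, batch_size)` grid come from the commit message.

```python
def kv_base_offsets(n_heads, batch_size, k_head_stride, k_batch_stride):
    """Map each CTA (blockIdx.y = head, blockIdx.z = batch) to its K/V base offset."""
    offsets = {}
    for batch in range(batch_size):        # blockIdx.z
        for head in range(n_heads):        # blockIdx.y
            offsets[(head, batch)] = batch * k_batch_stride + head * k_head_stride
    return offsets

seq_len, head_dim = 128, 64

# MHA: every query head has its own K/V slab.
mha = kv_base_offsets(n_heads=8, batch_size=2,
                      k_head_stride=seq_len * head_dim,
                      k_batch_stride=8 * seq_len * head_dim)

# MQA: a head stride of 0 makes all query heads alias the same K/V slab.
mqa = kv_base_offsets(n_heads=8, batch_size=2,
                      k_head_stride=0,
                      k_batch_stride=seq_len * head_dim)
```

With a zero head stride, every head in a batch resolves to the same offset, so no K/V copy is needed for MQA, and the number of dictionary entries equals the number of CTAs in the single launch.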
f2592ea0da fix: native TMEM columns for hd_chunk (no remapping) 2026-05-30 07:01:42 +00:00
3dbd3c5e7f debug: test chunk 1 only 2026-05-30 07:00:14 +00:00
9227b0e93f debug: skip hd_chunk>0 to isolate chunk0 2026-05-30 06:59:01 +00:00
25aeaca9ab fix: PV accumulate flag 2026-05-30 06:56:53 +00:00
1da785c070 D1.5: HD tiling (HD_CHUNK=256) for HD=512 support 2026-05-30 06:56:09 +00:00
5544d3a0a4 fix: TMEM reads must be outside my_row_active (warp-collective) 2026-05-30 04:48:26 +00:00
dd3e0fdfc8 D1.5: multi-row + multi-tile FMHA with SMEM accumulator in-kernel rescale 2026-05-30 04:37:33 +00:00
8b1ac380ac feat: HD=512 support — TMEM_N=512, test variants for all three TMA kernels 2026-05-30 03:45:05 +00:00
762f054d6d feat: double-buffer TMA pipeline in multi-row kernel 2026-05-30 03:20:49 +00:00
4a9c850e9c feat: double-buffer TMA pipeline for K loads in single-tile kernel 2026-05-30 03:14:06 +00:00
afa949071b fix: brace structure in V TMA conversion 2026-05-29 22:59:18 +00:00
ec577f71ee feat: V TMA loads in single-tile kernel too 2026-05-29 22:57:59 +00:00
422e7bb312 cleanup: v_head reference in multi-row (V via TMA now) 2026-05-29 22:54:44 +00:00