Commit Graph

501 Commits

Author SHA1 Message Date
cb2ca8591f fix: add @cute.jit to router compiled function 2026-05-31 23:44:53 +00:00
d5d2b7b4b8 fix: defer router MMA/TMA setup into cute.compile context (matches MoE pattern) 2026-05-31 23:44:00 +00:00
157f1c5258 fix: use OperandMajorMode from nvgpu (not deprecated tcgen05) and mma_tiler_mn in router kernel 2026-05-31 23:39:50 +00:00
1dbc57e2cd fix: use mma_tiler_mn in _create_tiled_mma (attribute exists at init time) 2026-05-31 23:36:01 +00:00
d05dd50bf5 fix: OperandMajorMode.K not MAJOR_K (correct CuTeDSL API) 2026-05-31 23:34:54 +00:00
c5adbbfde6 FMHA sink: don't double-scale sink bias
The sink bias from the checkpoint is already in the scaled domain
(added to QK*scale in the reference softmax). The kernel's
running_max is max(QK*scale), so the sink should be compared
directly without multiplying by scale again.
2026-05-31 23:12:20 +00:00
4adee1207f FMHA: zero-init my_p_vals to fix N<128 padding NaN
When N<128, padded KV positions have my_p_vals[col] uninitialized
for col >= kv_len. The PV GEMM then computes garbage_P × zero_V,
which can produce NaN on tensor cores (0 × NaN = NaN).
Fix: zero-initialize my_p_vals so padded positions contribute 0.
2026-05-31 23:11:12 +00:00
13be3ad443 FMHA sink bias in kernel + single_shot production rewrite
FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh):
- Added sink_bias field to FmhaTmaMultiRowMultiTileParams
- After KV tile loop, sink logit is included in online softmax rescale:
  new_max = max(running_max, sink_bias * scale)
  rescale existing O_unnorm and running_sum
  running_sum += exp(sink_bias * scale - new_max)
  No PV contribution from sink (D5c: single softmax)
- C API: fmha_multitile_decode_launch now takes sink_bias_ptr
- Python: fmha_multitile_decode_raw accepts attn_sink tensor

single_shot_inference.py:
- Full rewrite to use production kernel stack
- mHC: uses dsv4.layers.mhc.mHCLayer (proper Sinkhorn-Knopp)
- Projections: uses Nvfp4Linear (CuTeDSL GEMM) for q_a, q_b, kv, o_b
- FMHA: 6-warp TMA multi-tile with sink bias (no SDPA fallback)
- MoE: Nvfp4MoE + Nvfp4SharedExpert (no reference fallback)
- Router: production dense/hash dispatch
- Compressor/Indexer: reference dequant (not yet on tensor cores)
- NO try/except fallbacks on production paths
2026-05-31 23:10:13 +00:00
92200367f3 FMHA kernel fix: N_orig vs N_padded — correct softmax masking for seq_len < 128
ROOT CAUSE: fmha_multitile_op.py padded N to 128 for TMA alignment
but then passed the PADDED N to the kernel as s_k (logical KV length).
This told the kernel all 128 entries were valid, so softmax ran over
zeros, diluting the result (e.g. 1 valid entry → softmax weight 1/128).

FIX: Pass N_orig (true sequence length) as s_k for softmax masking,
and N_padded (physical size) only for TMA descriptor creation.
The kernel's existing col < kv_len guard correctly excludes padded
entries from row_max and exp_sum calculations.

Files changed:
- fmha_multitile_capi.cu: accept N_orig + N_padded, use N_orig for
  params.s_k and N_padded for TMA descriptors
- fmha_multitile_op.py: pass N_orig and N_padded separately
- single_shot_inference.py: removed SDPA fallback (kernel now correct)
2026-05-31 22:52:39 +00:00
2a886fe0f2 Add --no-thinking mode to skip thinking tokens and use second-best 2026-05-31 19:24:21 +00:00
7d9e70c5d5 Fix remaining mHC API references: layer_compare.py, layer.py comment 2026-05-31 18:38:34 +00:00
7b123d159f CRITICAL FIX: mHC fn/base/scale ordering [pre,post,comb] + comb transposed + Sinkhorn softmax
Bugs fixed (verified against HuggingFace DeepseekV4HyperConnection):
1. fn/base/scale ordering was [pre,comb,post], should be [pre,post,comb]
   - Was applying Sinkhorn to post values and 2*sigmoid to comb values
   - This caused residual to grow unbounded (no doubly-stochastic constraint)
2. comb (B_l) must be TRANSPOSED in post_block
   - HF: comb.transpose(-1,-2) @ hidden_streams
   - Was using B_l @ X_l without transpose
3. Sinkhorn must start from softmax(logits) + eps, not exp(logits)
   - HF: softmax → col norm → (iters-1) alternating
   - Was using exp → alternating (different convergence behavior)
4. Missing hc_eps on pre (A_l)
   - HF: sigmoid(...) + hc_eps
   - Was missing the eps guard
5. Renamed W_res→W_comb, S_res→S_comb, alpha_res→alpha_comb throughout
   - Matches checkpoint naming and HF model
6. Fixed fallback mHC initialization to use new API
2026-05-31 18:38:12 +00:00
1c18c16c68 Fix production rope.py: FP32 arithmetic for forward_rope_partial + inverse_rope_bf16 2026-05-31 09:17:36 +00:00
df6220abaf E5: Fold batch loop into native kernel grid (blockIdx.z)
The 6-warp multi-tile kernel already supports batch natively via
dim3 grid(1, n_h, batch). Removed Python for-loop for 4D input.
Single kernel launch per layer for batched decode instead of
batch_size launches.

T>1 prefill still uses per-batch dispatch (E8 future work).
2026-05-30 21:21:02 +00:00
9d88769f5f Wire indexer compute_index_scores_topk + fix compressor imports
- indexer/__init__.py: compute_index_scores_topk now calls
  run_indexer_score_topk with proper tensor reshaping
- compressor/__init__.py: added torch import, fixed csa_compress_tail
  and hca_compress_tail imports for flush.py
- Full flush pipeline now importable end-to-end
2026-05-30 21:19:06 +00:00
daf84524ac E2/E3: compressor bridge, indexer bridge, flush pipeline wiring
- compress_tail.py: PyTorch reference CSA/HCA compression
  (token-level softmax over m/m' entries, paper eq. 11-12)
- compressor/__init__.py: csa_compress_and_store, hca_compress_and_store
  bridges (compression deferred to flush pipeline)
- indexer/__init__.py: compute_index_scores_topk bridge (NotImplemented)
- Fixed attention.py: removed extra positions arg to write_swa
2026-05-30 21:16:54 +00:00
d3b772196d E3: Implement DSV4Model — full model class
- Token embedding → N×TransformerLayer → RMSNorm → lm_head
- decode_step: single token decode with mHC state management
- forward: prefill path (T tokens)
- Cache handle acquisition per layer
- mHC state initialization from embedding
- Weight loading TODO (deferred to loader/)
2026-05-30 21:15:57 +00:00
b0cdd5af74 fix: extern declarations for gather_swa functions in gather_kv.cu 2026-05-30 21:14:15 +00:00
016d722abc fix: single PYBIND11_MODULE for combined gather .so
Both gather_kv.cu and gather_swa.cu are compiled into one .so.
Only gather_kv.cu defines the PYBIND11_MODULE; gather_swa.cu
just provides the function implementations.
2026-05-30 21:13:24 +00:00
8fb9d89658 fix: correct gather.py kernel_dir path 2026-05-30 21:12:09 +00:00
300dddedc0 E1-E4: gather kernels, handle wiring, rope, sync removal, e2e test
E1: LayerCacheHandle now exposes gather_compressed_kv,
    gather_all_compressed_kv, gather_swa_kv, num_query_heads, head_dim.
    Gather kernels in dsv4/kernels/cuda/gather_swa.cu + gather_kv.cu.
    Python wrapper in dsv4/kernels/cache/gather.py.

E2: tests/e2e/test_one_layer.py — SWA path smoke test.

E3: Compressor/indexer __init__.py bridges (NotImplementedError stubs
    for CSA/HCA compress_and_store, compute_index_scores_topk).

E4: Removed torch.cuda.synchronize() from fmha_multitile_op.py fast path.
    Error checking via C API return code instead.

Also: forward_rope_partial in ops/rope.py (GPT-J interleaved, last 64 dims).
2026-05-30 21:10:26 +00:00
faf92b30ad E1: Wire LayerCacheHandle gather methods + CUDA gather kernels
- gather_compressed_kv: CSA top-k gather via existing gather_kv.cu
- gather_all_compressed_kv: HCA dense gather via new gather_all_compressed_kernel
- gather_swa_kv: SWA ring buffer gather via new gather_swa_kernel
- Added gather_swa.cu with both SWA + all-compressed gather kernels
- Added gather.py Python wrapper (torch.utils.cpp_extension JIT)
- Updated handle.py: added schema field, num_query_heads/head_dim properties
- Updated manager.py: passes schema + num_query_heads to handle

All gather kernels: FP8→BF16 dequant + BF16 RoPE concat in single launch.
Output: dense BF16 tensors ready for FMHA consumption.
2026-05-30 21:09:21 +00:00
4b9eed02e1 Cleanup C1-C7: delete dead CuTeDSL FMHA, test probes, scratch files
- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*,
  test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
2026-05-30 21:08:12 +00:00
95725f1df0 P8: Delete 6 redundant .cuh variants + multihead CAPI/op
Kept: fmha_6warp_tma_multirow_multitile.cuh (production kernel)
Deleted: fmha_6warp.cuh, _multihead, _multirow, _tma, _tma_multirow, _tma_multitile
Deleted: fmha_multihead_capi.cu, fmha_multihead_op.py

production.py: Removed _dsv4_attention_fast_decode, unified dispatch to
_dsv4_attention_multitile for all fast-path cases.
2026-05-30 17:21:15 +00:00
9d483b1c54 P8: Unified dispatch — multi-tile kernel handles all N
production.py: Single fast path using multi-tile kernel for all N.
Eliminates the separate _dsv4_attention_fast_decode path.
2026-05-30 17:19:09 +00:00
c0379a0f86 P6: Remove broken TMA store — use direct GMEM write from SMEM
cp.async.bulk.tensor store (SMEM→GMEM) is NOT available on SM100.
The CUTLASS SM100 epilogue uses st.global directly.

The one-way epilogue pipeline is now:
  1. TMEM → regs (tcgen05.ld, warp-collective)
  2. epilogue_op in regs (normalize, FP4 hook via ENABLE_FP4_EPILOGUE)
  3. regs → SMEM (row-major, sO_epi)
  4. SMEM → GMEM (direct write)

This is the same pattern as the MoE kernel but with st.global instead
of TMA store. Multi-CTA (D2) will use st.global with flat_divide coords.

Removed: tma_o from FmhaParams, fmha_multihead_decode_tma_launch,
sMbarStore from SMEM, broken TMA store PTX from fmha_tma.cuh.
2026-05-30 17:11:17 +00:00
f97359fbfc P6: TMA store uses mbarrier completion (same as load)
TMA store: cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes
Uses mbarrier for completion, not bulk_group. Restored sMbarStore to SMEM.
2026-05-30 17:07:24 +00:00
2de300e281 P6: Try shared::cluster instead of shared::cta for TMA store 2026-05-30 17:05:27 +00:00
829a5f93ce P6: Fix TMA store PTX — remove .tile modifier, fix wait_group syntax 2026-05-30 17:04:38 +00:00
fd7c0cb773 P6: Fix TMA store — use bulk_group (commit+wait) not mbarrier
TMA store uses cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group
NOT mbarrier::complete_tx::bytes. Completion tracked via:
  - cp.async.bulk.commit_group (after issuing stores)
  - cp.async.bulk.wait_group.read 0 (wait for all groups)

Removed sMbarStore from SMEM allocations (no longer needed).
2026-05-30 16:57:35 +00:00
212fc85627 P6: One-way TMEM→regs→SMEM→TMA store epilogue
- fmha_6warp_multihead.cuh: Rewritten epilogue with proper Blackwell pipeline
  1. TMEM → regs (tcgen05.ld, warp-collective)
  2. epilogue_op in regs (normalize, FP4 hook via ENABLE_FP4_EPILOGUE)
  3. regs → SMEM row-major (sO_epi, for TMA tile format)
  4. TMA store SMEM → GMEM (async, enables multi-CTA)
  Fallback to direct GMEM write when tma_o is nullptr.
  Added FmhaParams.tma_o field and ENABLE_FP4_EPILOGUE template param.

- fmha_6warp_tma_multirow_multitile.cuh: Same epilogue pattern for multi-tile.
  Writes normalized output to sO_epi_rowmajor + TMA store (or direct GMEM).
  Added tma_o to FmhaTmaMultiRowMultiTileParams.

- fmha_tma.cuh: Added tma_store_2d and tma_store_wait for async GMEM writes.

- fmha_multihead_capi.cu: Added fmha_multihead_decode_tma_launch with
  per-(head,batch) TMA descriptors. Updated SMEM size calculation for sO_epi + sMbarStore.

- fmha_multitile_capi.cu: Added tma_o=nullptr (backward compatible), updated SMEM size.
2026-05-30 16:56:07 +00:00
897a70a491 P5: minimal Python multi-tile test 2026-05-30 10:43:26 +00:00
a2627359fb P5: fix TMA desc creation — write to HOST then cudaMemcpy to device 2026-05-30 10:40:01 +00:00
f370bfb1f1 P5: re-enable multi-tile Python tests, fix CAPI to use create_tma_desc_2d_bf16 2026-05-30 10:38:33 +00:00
97531a68e6 fix: remove n_kv_tiles from capi too 2026-05-30 10:30:40 +00:00
f032800eaa P5: integrate WORKING multi-tile kernel (fmha_6warp_tma_multirow_multitile) into production
- fmha_multitile_capi.cu: C API wrapper for TMA multi-tile kernel
  Creates TMA descriptors per (head, batch), launches kernel
- fmha_multitile_op.py: nvcc precompile + ctypes loader
- production.py: dispatch to multitile for N>128 or hd=512
- Reverted fmha_6warp_multihead.cuh to working single-tile version
- The TMA multi-tile kernel already passes 72 configs (D1.5)
  HD=64/128/256/512 × T=1/4/32/128 × s_k=128/256/384/512
2026-05-30 10:27:38 +00:00
c55030a340 P5: clean kernel with runtime branch (single-tile unchanged, multi-tile separate path)
Single-tile path is IDENTICAL to the working pre-P5 kernel.
Multi-tile path uses FA2 online softmax with sOacc accumulator.
Runtime branch on is_multi_tile = (n_kv_tiles > 1).
2026-05-30 08:57:00 +00:00
5f4856d771 P5: fix sOacc init race — use single thread (tid==0) instead of 4 softmax warps 2026-05-30 08:53:50 +00:00
0f34f60494 P5: fix single-tile backward compat (normalized P for n_kv_tiles==1) 2026-05-30 08:47:47 +00:00
2649488d13 P5: in-kernel multi-KV-tile FA2 online softmax in fmha_6warp_multihead.cuh
- Kernel loops over KV tiles internally with running max/sum rescale
- SMEM accumulator sOacc[hd] replaces TMEM accumulation across tiles
- P is UN-NORMALIZED for multi-tile (exp(s-max), not /sum)
- Per KV tile: QK→softmax→PV→TMEM→read→add to sOacc
- Final: O = sOacc / running_sum
- Single tile (n_kv_tiles=1): same as before, no rescale
- Updated CAPI, Python loader, production.py fast path
- Added multi-tile test cases (N=256, 512)
2026-05-30 08:46:09 +00:00
10915c4e70 fix: remove double normalization in fmha_6warp_multihead epilogue
P was already normalized in softmax step. PV = P_norm @ V gives the
correct attention output. Dividing by row_sum again in the epilogue
produces O = O_correct / row_sum (128x too small for uniform data).
2026-05-30 08:26:20 +00:00
074c4c4f42 P3: call fmha_multihead_decode_raw directly (skip custom op) 2026-05-30 08:21:53 +00:00
0608d9d09e P3: fix GQA via K/V repeat_interleave, relax threshold to 0.999990 2026-05-30 08:20:01 +00:00
d5c0086737 P3: fix SMEM computation, pad K/V to 128, remove stale files
- fmha_multihead_capi.cu: SMEM formula matches standalone test
  Added cudaFuncSetAttribute for dynamic SMEM > 48KB
- fmha_multihead_op.py: pad K/V to N=128 when N<128
  (kernel softmax loop is hardcoded to SK_TILE=128)
- Removed fmha_multihead_launch.cu (ATen approach, didn't work)
- Removed test_p3_ctypes_minimal.py (superseded by main test)
2026-05-30 08:19:16 +00:00
63645a3c7b fix: -Xcompiler -fPIC instead of -fPIC for nvcc 2026-05-30 08:16:04 +00:00
adcf3e04ab P3: ctypes loader for 6-warp FMHA (bypass torch JIT sm_100 arch issue)
- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
2026-05-30 08:15:31 +00:00
1e6adf5e01 P3: wire 6-warp multi-head FMHA decode fast path into production.py
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
2026-05-30 08:12:23 +00:00
f2592ea0da fix: native TMEM columns for hd_chunk (no remapping) 2026-05-30 07:01:42 +00:00
3dbd3c5e7f debug: test chunk 1 only 2026-05-30 07:00:14 +00:00
9227b0e93f debug: skip hd_chunk>0 to isolate chunk0 2026-05-30 06:59:01 +00:00