Commit Graph

539 Commits

Author SHA1 Message Date
bb3ad3d2ef BREAKTHROUGH: cosine 0.993 for n=128! PV-partitioned P row sum works.
C9 fix: instead of using QK-partitioned row_sum (which maps to wrong PV rows),
read P from TMEM using PV partition and sum via .reduce(ADD).

QK: thread N owns row N//4, PV: thread N owns row N.
Reading P via PV partition gives each thread its correct row P values.

n=128: cosine 0.993 (was 0.514)
n=256: cosine 0.725 (C6 still broken for multi-tile)
n=384: cosine 0.676 (same C6 issue)

Remaining: C6 O-rescale for multi-tile needs same PV-partitioned fix.
Small accuracy gap (0.993 vs 0.999) likely from BF16 P store/load round-trip.
2026-05-21 20:13:51 +00:00
7d1c402a6d WIP: TMEM vector bridge not working (same cosine 0.513)
row_sum is PROVEN correct (29.25 vs 29.22 for row 0, ratio 1.001).
The ONLY bug is QK→PV row mapping in C9 normalization.

Tried: composition(tStS,(128,1)) for write, composition(tOtO,(128,1)) for read.
Same result — the composition preserves the fragments internal thread-to-address
mapping, so the same thread writes and reads the same TMEM address regardless
of which fragment layout is used for the composition.

Need: absolute row-coordinate indexed TMEM vector. Each QK thread writes
inv_row_sum to vec[QK_row_id], each PV thread reads from vec[PV_row_id].
The row_id comes from the identity tensor coordinate.

Alternative: implement FMHA correction_epilog pattern with dedicated
correction warp group that reads row metadata from the vector.
2026-05-21 19:26:15 +00:00
cae87fd744 WIP: confirmed row_sum is wrong (5.5 vs correct 29.22 for row 0)
The packed f32x2 reduction SHOULD sum all 128 exp2 P values but gives
a result ~5.3x too small. Need to debug inside the kernel with print
statements to see what values the reduction is actually summing.

Unnormalized P@V is perfect (cosine 0.999998). row_max is correct
(because P is correct). The bug is specifically in row_sum computation.
2026-05-21 19:16:15 +00:00
8eb569e31c BREAKTHROUGH: Found the real bug!
Each epilogue thread owns only 32 of 128 columns per row (4 warps partition
the 128 columns). The scalar row_max is max of 32 elements (not global row max).
The scalar row_sum is sum of 32 P values (not all 128).

Evidence: kernel row_sum=5.5 for row 0, correct=29.22, ratio=5.3x
(128/32 = 4, but the P values arent uniformly distributed so not exact).

Fix: need warp reduce across the 4 threads that own the same rows columns.
CUTLASS FMHA uses utils.sm100 warp-reduce helpers for this exact reason.
2026-05-21 19:10:47 +00:00
c09c660110 WIP: scalar C9 normalization - confirmed inv_row_sum is wrong
The C9 TMEM round-trip IS modifying O (confirmed by epilogue * 2.0 test).
But inv_row_sum is wrong: each thread computes row_sum via .reduce(MAX) and
packed f32x2 reduction, but the result appears to be the same for all threads.

Next: need to dump the QK C-fragment coordinate tensor to understand
which rows each thread actually owns in the TMEM load partition.
2026-05-21 19:09:32 +00:00
ce91aa26e4 WIP: QK-partitioned C9 normalization (does not work)
The QK composition(tStS, (128,64)) view of O TMEM region does not align
with the actual PV C-fragment layout. Cannot read O with QK partition.

Need to use TMEM vector approach:
1. Store inv_row_sum via QK partition (composition(tStS, (128,1)))
2. Read inv_row_sum via PV partition (need PV-partitioned view of vector)
3. Apply normalization in PV-partitioned O TMEM access

The key challenge: creating a PV-partitioned read of the vector TMEM region
that was written with QK partition. This is what CUTLASS FMHA does with
its correction warp group.
2026-05-21 18:59:21 +00:00
1fa093ee12 BREAKTHROUGH: unnormalized P@V cosine 0.999998 for n=128!
The softmax math (exp2, P store, PV) is correct for single-tile.
The bug is ONLY in C6/C9 normalization: applying inv_row_sum
using PV partition instead of QK partition.

n=128 (single tile): cosine 0.999998 PASS
n=256/384 (multi-tile): C6 O-rescale using wrong partition = FAIL

Fix: normalize O using QK row coordinates, not PV row coordinates.
Can use TMEM vector to bridge QK partition to PV partition.
2026-05-21 18:55:00 +00:00
c2901b2ecc WIP: TMEM vector for per-row row_sum (not yet working)
Key finding: the root cause is that each epilogue thread owns MULTIPLE rows
in the QK C-fragment, so scalar row_max/row_sum are wrong (global across
all rows, not per-row). The V=ones diagnostic confirmed: all 128 threads
use the same row_sum (from row 114).

Tried: TMEM vector store+load of row_sum (composition(tStS, (128,2))).
This is a no-op because both write and read use the SAME QK partition
with a scalar row_sum. The vector approach only helps when different
partitions are used for write vs read, or when per-row values are stored.

Next steps:
1. Need PER-ROW row_max and row_sum, not per-thread scalar
2. The CUTLASS FMHA works because each thread owns exactly 1 row
3. Options: restructure thread layout, or compute per-row values differently
4. The vector must store ALL 128 per-row values, then read per-row in C9
2026-05-21 18:45:30 +00:00
4c203809ef WIP: Stage C softmax - partial progress
Key finding: cute.size(v, mode=[0]) in @cute.jit produces wrong code.
Hardcoding s_k=128 (matching Stage B) fixes the base pipeline.

Current status: kernel produces non-zero output but softmax math is still wrong.
Applied fixes: pv_done_bar, acc_scale with scale, fastmath=True
Need to debug row_sum computation and C9 normalization.
2026-05-21 18:04:21 +00:00
8e1facef01 Stage C fixes: pv_done_bar sync, acc_scale with scale, fastmath=True
- Add pv_done_bar (barrier_id=4): MMA signals PV complete, epilogue
  waits before O rescale (C6) and final normalization (C9)
- Fix acc_scale: exp2(scale * (old_max - new_max)) includes the
  scale_softmax_log2 factor matching CUTLASS FMHA reference
- fastmath=True for both exp2 calls (P computation + rescale)
- No *0.5 (our scalar row_sum pattern initializes (0,0) not (sum,sum))
2026-05-21 17:58:04 +00:00
58ca480fd1 Stage C: add validation harness with real softmax reference (C1) 2026-05-21 17:49:26 +00:00
e8485b9cf5 README: add full DSV4 pipeline architecture diagram (CSA/HCA, not MLA) 2026-05-21 17:40:25 +00:00
364d9edcd3 README: update for new dsv4/ package structure 2026-05-21 17:34:40 +00:00
9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00
94b26ebdf0 Fix: add scale_softmax_log2, use O TMEM rescale for C9 normalization
- scale_softmax_log2 was missing from _setup (patch artifact)
- C9 normalization: load O from TMEM, multiply by 1/row_sum, store back
  instead of trying to capture runtime value in const_expr lambda
- Then use standard epilogue_tma_store with identity transform
2026-05-21 17:15:15 +00:00
3f6f4a3ad8 Stage C: online softmax kernel (WIP) - test_fmha_v3_softmax.py
- C1: Real softmax reference (torch.softmax, not identity)
- C2: Per-thread row_max/row_sum registers
- C3: QK scale folded (1/sqrt(d) * log2(e))
- C4: Row max via .reduce(MAX)
- C5: Rescale factor (exp2(old_max - new_max))
- C6: O rescale in TMEM (correction_rescale pattern)
- C7: Real exp2 for P computation
- C8: Row sum via packed f32x2 reduction
- C9: Final normalization (1/row_sum in epilogue)
- Dynamic s_k for V FMHA reconstruction
- fastmath=False for correctness first
2026-05-21 17:10:58 +00:00
ed712c4939 README: full roadmap — Stage C (real softmax), D (paged KV), E (production kernel)
Document canonical test files, obsolete test sprawl, and the path from
test_fmha_v3.py → cutedsl/kernel/attention/fmha_kernel.py → vLLM integration.
Also: TMEM layout for Stage C, key lessons from A&B.
2026-05-21 15:43:01 +00:00
b4f6c6b702 Update both READMEs: Stage B complete, document TMEM overlap root cause
- Workspace README: full rewrite with Stage B , Bug 4b root cause (P/O overlap),
  FMHA V reconstruction, TMEM layout diagram, softmax store pattern, updated footguns
- Kernel README: focused on the bug, fix, and current test status
- Key lesson documented: NEVER use find_tmem_tensor_col_offset() as O placement
2026-05-21 15:36:06 +00:00
2d10319e9f Fix TMEM overlap in test_pv64_with_softmax.py too — cosine 0.999999
Same P/O overlap bug: O at col 64 overlapped P at [32,96).
Same fixes: O at col 128, FMHA V reconstruction, power-of-2 TMEM alloc.
2026-05-21 15:32:49 +00:00
73f38acf74 STAGE B BUG 4b FIXED: TMEM P/O overlap + FMHA V reconstruction
Root cause: PV output O started at TMEM column 64 (from find_tmem_tensor_col_offset),
overlapping with P at columns [32,96). PV MMA reading P while writing O to overlapping
columns corrupted the A operand mid-computation.

For (128,128) PV, O started at 128 (no overlap) so it worked by accident.
For (128,64) PV, O started at 64, overlapping P [32,96) -> NaN/garbage.

Fix: Place O at column 128 (after both S [0,128) and P [32,96)).
Also added FMHA-style V reconstruction: logical (HEAD_DIM, s_k, 1) stride (1, hd, hd*s_k)
instead of passing DLPack V directly to TMA.

test_fmha_v3.py: (128,64) PV with random V -> cosine 0.999999 PASS
2026-05-21 15:30:24 +00:00
dba4279364 Stage B Bug 4b debugging: P/A alias proven working, V layout issue for (128,64) PV
Key findings:
- P/A alias WORKS: PV reads non-zero P from TMEM at offset 32 (proven by no-softmax test)
- V mode bug: V=(128,64) only loads 64 K-values, PV needs 128. Output = sum(S[:,:64]) = 0.67 cosine
- FMHA-style V reconstruction (hd,n,1) stride (1,hd) gives NaN for (128,64) PV
- K-major V (64,128) contiguous gives NaN for (128,64) PV
- Square (128,128) PV works with ALL V approaches (cosine 0.999999)
- Non-square PV consistently broken regardless of V layout

Test files:
- test_128_128_fmha_v.py: (128,128) with FMHA V - PASS
- test_pv64_fmha_v.py: (128,64) with FMHA V - NaN
- test_pv64_kmajor_v.py: (128,64) with K-major V - NaN
- test_pv64_with_softmax.py: (128,64) with original V - 0.67
- test_pv64_no_softmax.py: proves P/A alias works
- test_fmha_v3.py: full pipeline with QK C-fragment composition store
2026-05-21 15:20:14 +00:00
5b9d0280d9 FMHA v3: KV-tile interleaving pipeline - QK works, Bug 4b blocks PV 2026-05-21 12:52:29 +00:00
b0912fbf85 Stage B: PV(128,64) test + v2 pipeline fixes
- test_pv64.py: (128,64) PV with separate V SMEM, single ab pipeline
  Result: cosine 0.669848 — data path works but P layout mismatch
  Softmax writes P via QK C-fragment layout, PV reads via PV A-fragment layout
  These differ for non-(128,128) PV — Bug 1 from README

- test_fmha_v2_fixed.py: KV-tile interleaved pipeline with fixes
  Fix 1: per-pipeline tx_count (Q vs KV separate byte counts)
  Fix 2: NamedBarrier for softmax-done signal (replaces double-acquire deadlock)
  Fix 3: Separate SMEM for V (no recast_ptr overlap with K)
  Still produces zeros — needs P layout fix (same root cause as test_pv64)
2026-05-21 11:49:06 +00:00
579b2f6abe stuff and stuff 2026-05-21 10:50:30 +00:00
04cbea072e FMHA v1: pv_mma_tiler=(128,64,128) works with V=I, fails with real V (SMEM layout bug) 2026-05-21 10:47:46 +00:00
241ac2bf94 README: Bug 4 ROOT CAUSE CONFIRMED - V SMEM 1 K-tile + PV 8 K-phases mismatch. Zero-pad V workaround correct. 2026-05-21 09:59:37 +00:00
6c0a6cf50b Root cause FOUND: V SMEM only holds 1 K-tile (2048 BF16), but PV MMA iterates 8 K-phases. For non-(128,128) PV, most K-phases read wrong V data. Zero-padded V works because V is (128,128) covering all 8 K-phases. FMHA interleaves QK+PV per KV-tile to avoid this. 2026-05-21 09:56:54 +00:00
c59974bcef README: Bug 4 corrected — NOT TMEM alias. A-fragment identical for all PV sizes. Real bug in V/B or output C/D. 2026-05-21 09:47:08 +00:00
80da5c51a6 Key finding: PV A-fragment layout is IDENTICAL for (128,128)/(128,32)/(128,16) PV. Bug is NOT TMEM alias. cta_tile_shape_mnk wrong for non-(128,128) PV. V SMEM and O C-fragment sizes look correct. Debugging V/epilogue paths. 2026-05-21 09:44:22 +00:00
464390fcfc Update README: Bug 4 status, (128,16) PV zero output, (128,128) PV zero-pad workaround (cosine 1.0) 2026-05-21 09:20:09 +00:00
6a352cf6da TMEM alias analysis: (128,16) PV broken, (128,128) PV with zero-pad works. Root cause: PV A-fragment layout differs from QK C-fragment layout for (128,16) PV, causing TMEM column mismatch. Using (128,128) PV as workaround. 2026-05-21 09:10:12 +00:00
73cb3a3277 Debugging TMEM alias for (128,16) PV: zero output confirmed, PV reads from wrong TMEM columns. Need to align softmax P write with PV A-fragment layout. 2026-05-21 09:00:42 +00:00
768a696373 Stage B N-tiling: (128,16) PV MMA compiles and runs, cosine 0.36 (TMEM alias mismatch bug). FMHA head_dim=64 passes. Debugging TMEM layout alignment. 2026-05-21 08:45:49 +00:00
f3a7dc1598 FOOTGUN #0: num_tma_load_bytes MUST include V bytes. Fix v27, v29, comment all. Update README. 2026-05-21 07:13:14 +00:00
d873284e84 v29: FIX DEADLOCK - add V bytes to num_tma_load_bytes. V=I(128,128) cosine 1.0 2026-05-21 07:08:29 +00:00
21c9a4823e README: update with v28/v29 deadlock investigation, FMHA softmax bridge trace, new footguns 2026-05-21 06:46:02 +00:00
6ec308dd50 v29 (padded V, deadlocks), v30 (diag copy, works) — debugging epilogue deadlock with (128,128) PV 2026-05-21 06:40:27 +00:00
261fe5e43e even more stuff 2026-05-21 05:55:22 +00:00
3fdb9f008b v28 attempt: PV MMA (128,64) - cosine 0.004, debugging 2026-05-21 05:41:44 +00:00
ce63ec7ee9 README: Bug 4 root cause — TMEM layout mismatch (128,64) PV A-fragment vs softmax P write
- (128,64) PV MMA A-fragment has N_MMA=64, reads P with wrong stride
- Softmax writes P with QK C-fragment layout (N_MMA=128)
- O[m,d] ≈ P[m,2d] — every other column effect confirmed
- All-ones and single-element V pass (uniform/sparse data hides mismatch)
- epi_tile must use PV cta_tile (partial fix: 0.01 → 0.876)
- Added footguns #9 (TMEM alias N_MMA match) and #10 (epi_tile)
- Added diagnostic test results to test table
2026-05-21 05:17:12 +00:00
1f6fb856ea more stuff 2026-05-21 05:08:57 +00:00
1314a67135 Stage B progress: PV works for square (128,128), broken for (128,64)
- Bug 1 (V MN-major): Fix applied
- Bug 2 (softmax packing): Confirmed correct (V=I test: cosine 1.0)
- Bug 3 (ACCUMULATE): Fix applied (first PV must overwrite, not accumulate)
- Bug 4 (CURRENT): PV MMA broken for non-square output
  - (128,128) PV with random V: cosine 0.999999 
  - (128,64) PV with MN-major V: cosine ~0.01 
  - Softmax packing, layout aliasing, pipeline ordering all verified correct
  - Root cause unknown — likely epilogue/V layout/MMA tiler issue

Added test_pv_diag.py (V=I and random V, 128x128 output — PASS)
Added test_layout_compare.py (TMEM layout inspection)
Added test_inspect_types.py (TMEM pointer arithmetic verification)
Updated test_mma_si_pv.py with head_dim param, pv_mma_tiler_mn fix, ACCUMULATE fix
Updated READMEs with current state
2026-05-21 04:40:28 +00:00
2cbb0369fd Stage B: pipeline deadlock fixed, V MN-major applied, PV output garbage
Pipeline deadlock fixed:
- No cta_layout_vmnk on mma_si PipelineUmmaAsync
- TMA warp excluded from tmem.wait_for_alloc
- PipelineTmaStore (not TmaStorePipeline)

Bug 1 (V MN-major): fix applied
- PV MMA uses v_major=OperandMajorMode.MN
- V shaped (64,128) strides(1,64) via as_strided

Bug 2 (softmax packing): C-fragment composition store applied
- FP32 to BF16 packing works
- St32x32bOp uses Float32 (not BFloat16)

Bug 3 (PV garbage): investigating
- PV MMA cosine ~0.01 against reference
- Suspected TMEM layout mismatch between softmax P store and PV A-fragment read

Test results:
- test_mma_si_only: cosine 0.999999 PASS
- test_mma_si_pv: cosine 0.01 FAIL (pipeline works, PV output wrong)
2026-05-21 04:10:07 +00:00
1eca10e39f Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00
e678afcde0 Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong
Key fixes:
- PipelineUmmaAsync consumer group: 32*4=128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 
- Stage B: runs without crash, identity softmax cosine -0.02 
- Diagnostics: TMEM layout inspection, bisection results
2026-05-20 20:26:25 +00:00
d3a7e7a286 stuff 2026-05-20 07:15:01 +00:00
1b4742a438 Update README: reflect current state, add C128A/C4A topk + warmup fixes 2026-05-20 06:51:12 +00:00
a793751fe8 Fix warmup compilation + add sparse topk metadata kernels
Bug #2 fix: warmup_compilation and warmup_fused_swiglu_compilation now
use valid FP4 data by quantizing random BF16 through quantize_to_nvfp4.
Random uint8 bytes as FP4 bit patterns cause cudaErrorIllegalInstruction
in Blackwell MMA hardware. Re-enabled warmup calls in runner.py.

Bug #1 kernel: sparse_topk_metadata.cu with:
  - build_c128a_topk_metadata: position-based compressed KV slot lookup
    via block table for C128A (compress_ratio=128) decode tokens
  - compute_c4a_global_topk: local topk index -> global slot ID mapping
    via block table for C4A (compress_ratio=4) decode tokens
  - Both tested: correct block table lookups, proper padding

Bug #3 kernel: C4A uses compute_c4a_global_topk (same .cu file)
  - Replaces vLLM Triton kernel with our own CUDA kernel

Deleted stale STATUS.md, FUSED_EPILOGUE_STATUS.md, FUSED_EPILOGUE_PLAN.md, CURRENT_BUGMD
2026-05-20 06:43:43 +00:00
dd7af0cd8a feat: GPU-native SWA + sparse decode attention kernels (CuTeDSL)
- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
    requires 32-bit aligned vector, no scalar fp8->bf16 support)
  - Cosine 0.9999+ vs PyTorch batched SDPA reference
  - Fallback _fallback_batched_sdp when CuTeDSL unavailable

- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine 0.9999+ vs combined SDPA reference

- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.

- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
  - csa_attention.py updated to use native kernels

- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
2026-05-20 05:46:15 +00:00
16b3094bdb README: comprehensive update with current kernel status 2026-05-20 04:42:57 +00:00