Commit Graph

11 Commits

Author SHA1 Message Date
bb3ad3d2ef BREAKTHROUGH: cosine 0.993 for n=128! PV-partitioned P row sum works.
C9 fix: instead of using QK-partitioned row_sum (which maps to wrong PV rows),
read P from TMEM using PV partition and sum via .reduce(ADD).

QK: thread N owns row N//4, PV: thread N owns row N.
Reading P via PV partition gives each thread its correct row P values.

n=128: cosine 0.993 (was 0.514)
n=256: cosine 0.725 (C6 still broken for multi-tile)
n=384: cosine 0.676 (same C6 issue)

Remaining: C6 O-rescale for multi-tile needs same PV-partitioned fix.
Small accuracy gap (0.993 vs 0.999) likely from BF16 P store/load round-trip.
2026-05-21 20:13:51 +00:00
7d1c402a6d WIP: TMEM vector bridge not working (same cosine 0.513)
row_sum is PROVEN correct (29.25 vs 29.22 for row 0, ratio 1.001).
The ONLY bug is QK→PV row mapping in C9 normalization.

Tried: composition(tStS,(128,1)) for write, composition(tOtO,(128,1)) for read.
Same result — the composition preserves the fragments internal thread-to-address
mapping, so the same thread writes and reads the same TMEM address regardless
of which fragment layout is used for the composition.

Need: absolute row-coordinate indexed TMEM vector. Each QK thread writes
inv_row_sum to vec[QK_row_id], each PV thread reads from vec[PV_row_id].
The row_id comes from the identity tensor coordinate.

Alternative: implement FMHA correction_epilog pattern with dedicated
correction warp group that reads row metadata from the vector.
2026-05-21 19:26:15 +00:00
cae87fd744 WIP: confirmed row_sum is wrong (5.5 vs correct 29.22 for row 0)
The packed f32x2 reduction SHOULD sum all 128 exp2 P values but gives
a result ~5.3x too small. Need to debug inside the kernel with print
statements to see what values the reduction is actually summing.

Unnormalized P@V is perfect (cosine 0.999998). row_max is correct
(because P is correct). The bug is specifically in row_sum computation.
2026-05-21 19:16:15 +00:00
c09c660110 WIP: scalar C9 normalization - confirmed inv_row_sum is wrong
The C9 TMEM round-trip IS modifying O (confirmed by epilogue * 2.0 test).
But inv_row_sum is wrong: each thread computes row_sum via .reduce(MAX) and
packed f32x2 reduction, but the result appears to be the same for all threads.

Next: need to dump the QK C-fragment coordinate tensor to understand
which rows each thread actually owns in the TMEM load partition.
2026-05-21 19:09:32 +00:00
ce91aa26e4 WIP: QK-partitioned C9 normalization (does not work)
The QK composition(tStS, (128,64)) view of O TMEM region does not align
with the actual PV C-fragment layout. Cannot read O with QK partition.

Need to use TMEM vector approach:
1. Store inv_row_sum via QK partition (composition(tStS, (128,1)))
2. Read inv_row_sum via PV partition (need PV-partitioned view of vector)
3. Apply normalization in PV-partitioned O TMEM access

The key challenge: creating a PV-partitioned read of the vector TMEM region
that was written with QK partition. This is what CUTLASS FMHA does with
its correction warp group.
2026-05-21 18:59:21 +00:00
1fa093ee12 BREAKTHROUGH: unnormalized P@V cosine 0.999998 for n=128!
The softmax math (exp2, P store, PV) is correct for single-tile.
The bug is ONLY in C6/C9 normalization: applying inv_row_sum
using PV partition instead of QK partition.

n=128 (single tile): cosine 0.999998 PASS
n=256/384 (multi-tile): C6 O-rescale using wrong partition = FAIL

Fix: normalize O using QK row coordinates, not PV row coordinates.
Can use TMEM vector to bridge QK partition to PV partition.
2026-05-21 18:55:00 +00:00
c2901b2ecc WIP: TMEM vector for per-row row_sum (not yet working)
Key finding: the root cause is that each epilogue thread owns MULTIPLE rows
in the QK C-fragment, so scalar row_max/row_sum are wrong (global across
all rows, not per-row). The V=ones diagnostic confirmed: all 128 threads
use the same row_sum (from row 114).

Tried: TMEM vector store+load of row_sum (composition(tStS, (128,2))).
This is a no-op because both write and read use the SAME QK partition
with a scalar row_sum. The vector approach only helps when different
partitions are used for write vs read, or when per-row values are stored.

Next steps:
1. Need PER-ROW row_max and row_sum, not per-thread scalar
2. The CUTLASS FMHA works because each thread owns exactly 1 row
3. Options: restructure thread layout, or compute per-row values differently
4. The vector must store ALL 128 per-row values, then read per-row in C9
2026-05-21 18:45:30 +00:00
4c203809ef WIP: Stage C softmax - partial progress
Key finding: cute.size(v, mode=[0]) in @cute.jit produces wrong code.
Hardcoding s_k=128 (matching Stage B) fixes the base pipeline.

Current status: kernel produces non-zero output but softmax math is still wrong.
Applied fixes: pv_done_bar, acc_scale with scale, fastmath=True
Need to debug row_sum computation and C9 normalization.
2026-05-21 18:04:21 +00:00
8e1facef01 Stage C fixes: pv_done_bar sync, acc_scale with scale, fastmath=True
- Add pv_done_bar (barrier_id=4): MMA signals PV complete, epilogue
  waits before O rescale (C6) and final normalization (C9)
- Fix acc_scale: exp2(scale * (old_max - new_max)) includes the
  scale_softmax_log2 factor matching CUTLASS FMHA reference
- fastmath=True for both exp2 calls (P computation + rescale)
- No *0.5 (our scalar row_sum pattern initializes (0,0) not (sum,sum))
2026-05-21 17:58:04 +00:00
58ca480fd1 Stage C: add validation harness with real softmax reference (C1) 2026-05-21 17:49:26 +00:00
9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00