Commit Graph

217 Commits

Author SHA1 Message Date
2147cce95d Stage C: integrate example3 multi-tile fixes into unit test
- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
2026-05-22 16:39:45 +00:00
e4c82873bb FMHA Stage-C multi-tile: combined K+V barrier, final_o_bar, acc_pipe producer
Key changes from Mike:
1. Combined K+V TMA barrier: one acquire per kt, both cute.copys share
   kvh.barrier. kvh.count naturally == kt (no interleaving problem).
   tx_count = K_bytes + V_bytes. Also fixes the sK[0]/sV[1] slot quirk.
2. final_o_bar NamedBarrier: MMA .arrive() after acc_pipe.producer_tail;
   softmax .arrive_and_wait() before reading O for normalize. Prevents
   softmax racing MMA's PV[N-1] on the final O read.
3. acc_pipe producer in MMA: producer_acquire before loop, commit+advance
   after loop, producer_tail after. Consumer in epilogue as before.
4. O rescale re-enabled for kt>0 with acc_scale before softmax_done_bar.
2026-05-22 16:23:36 +00:00
39fa5b96b0 restore tBgK to kh.count indexing (single-tile working), add TODO for multi-tile
CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing.
kh.count works for single tile. For multi-tile, need to either:
1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count

This is the architectural blocker for multi-tile FMHA.
2026-05-22 15:54:03 +00:00
f1854bab26 FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing
The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
2026-05-22 15:52:56 +00:00
3d2cb0e52b CRITICAL FIX: keep GMEM iteration dim free in tBgK/tVgV slice
The slice (None,0,None,0) was hardcoding the GMEM iteration dim to 0,
meaning TMA always loaded K/V from tile 0 regardless of kt.
Changed to (None,None,None,0) to keep gmem_iter free,
then index with (None, kt, None) in the TMA copy loop.

This is the root cause of multi-tile failure: TMA was always reading
the first 128 tokens for ALL KV tiles.
2026-05-22 15:52:06 +00:00
a04b219f0f add explicit acc_pipe.consumer_wait before final normalize
Race condition: softmax reads O to normalize while MMA may still be
writing PV[N-1]. Single-tile wins by luck; multi-tile drifts.
Move acc_cons_st construction before the wait so epilogue reuses it.
2026-05-22 15:49:48 +00:00
ff27a261b1 FMHA Stage-C multi-tile: Fix 1 (s_k=n), Fix 2 (TMA kt indexing), Fix 3 (O rescale)
Fix 1: s_k must equal actual n. With s_k < n, v_fmha layout only spans
first s_k V tokens and TMA reads OOB on later tiles.

Fix 2: TMA producer indexes K and V by kt (loop variable), NOT by the
pipeline's interleaved count. The kv pipeline interleaves K and V, so
pipeline count goes 0,1,2,3 but GMEM tiles should be K[0],V[0],K[1],V[1].

Fix 3: Online O rescale before softmax_done_bar. When row_max grows,
O must be multiplied by exp2(old_max - new_max) before MMA starts next PV.
2026-05-22 15:41:14 +00:00
5a08b79364 Revert "debug: test 12w identity softmax with n=256 to verify multi-tile pipeline"
This reverts commit 6cf8702e3c.
2026-05-22 10:25:48 +00:00
6cf8702e3c debug: test 12w identity softmax with n=256 to verify multi-tile pipeline 2026-05-22 10:24:53 +00:00
a3c9af8fa3 debug: disable O rescaling to test multi-tile pipeline baseline 2026-05-22 10:23:37 +00:00
c175ec4f09 fix: revert to scaled row_max, use exp2(old_max - new_max) for O rescaling
row_max is in scaled domain (s_val * scale_log2). The O rescaling
should be exp2(old_max - new_max) without extra scale_log2 because
the max values already include the scaling factor.
2026-05-22 10:22:44 +00:00
35c8043064 fix: compute row_max from RAW S values, not scaled
row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
2026-05-22 10:21:50 +00:00
f9f5647eaa fix: missing newline after self.s_k = s_k 2026-05-22 10:20:35 +00:00
e0c320929a fix: add s_k param to FmhaV3StageC, use self.s_k for V FMHA reconstruction 2026-05-22 10:19:49 +00:00
fb4ffd8cf7 Stage C: add online O rescaling for multi-tile KV + test n=256
- Move O TMEM load/store setup before softmax loop
- After P store: rescale O in TMEM by exp2((old_max - new_max) * scale)
- Only rescale for kt > 0 (first tile has no prior O to rescale)
- Use same TMEM load/modify/store pattern as final normalization
- Test both n=128 (1 tile) and n=256 (2 tiles)
2026-05-22 10:19:08 +00:00
94b0d97107 fix: add epilogue warp to tmem_bar, restore wait_for_alloc in epilogue
The epilogue needs tmem_ptr for epilogue_tma_store.
It must be part of the tmem alloc barrier to synchronize.
2026-05-22 10:17:02 +00:00
65e52f5934 fix: add softmax_done_bar to synchronize MMA PV with softmax P production
MMA must wait for softmax to produce P in TMEM before starting PV.
Without this, MMA reads stale P data from TMEM, causing deadlock.
softmax_done_bar: softmax warps arrive after P store, MMA waits before PV.
2026-05-22 10:15:26 +00:00
ea687980af fix: epilogue warp self-signals acc_pipe producer before consuming 2026-05-22 10:11:55 +00:00
19b742f365 fix: remove duplicate tmem free from epilogue (MMA warp handles dealloc) 2026-05-22 10:05:52 +00:00
0a3815049f fix: add acc_pipe pipeline for epilogue, matching 12w pattern
- Add acc_bar to SS struct
- Create acc_pipe (full pipeline) before if blocks
- Pass acc_pipe to epilogue_tma_store (needs full pipeline, not participant)
2026-05-22 10:03:08 +00:00
59f4d8a469 fix: epilogue_warp_id must be tuple for epilogue_tma_store, check with [0] 2026-05-22 09:59:20 +00:00
6ba12b7890 fix: epilogue warp reuse mma_corr_cons pipeline instead of creating new one from st 2026-05-22 09:56:18 +00:00
540399eca3 fix: define cS and tScS in correction warps (not visible across if blocks) 2026-05-22 09:52:59 +00:00
ee859099bd fix: correct @cute.kernel indentation 2026-05-22 09:49:36 +00:00
fc7a790fbd fix: remove duplicate @cute.kernel decorator 2026-05-22 09:46:09 +00:00
78aac51ab9 FMHA Stage-C2: production 12-warp pipeline with correction warps
- Softmax warps (0-3): S→softmax→P, vec=[old_max,new_max]→TMEM
- Correction warps (4-7): O rescale in TMEM, final normalize by row_sum
- MMA warp (8): QK→S, PV→O with pipeline chaining
- TMA warp (9): Q/K/V load
- Epilogue warp (10): O TMEM→GMEM via epilogue_tma_store
- Empty warp (11): tmem dealloc mbar init
- Pipeline: mma_s→softmax→s_corr→correction→corr_epi→epilogue + mma_corr→correction
- Supports multi-tile KV with online O rescale
- Follows CUTLASS FMHA correction_rescale pattern exactly
2026-05-22 09:42:39 +00:00
c82c1ddc1b test: add multiple seeds to verify softmax consistency 2026-05-22 09:32:08 +00:00
a24b3e75a2 fix: use plain range loop for row_max (fmax not allowed in vectorized) 2026-05-22 09:31:07 +00:00
c96454d70b fix: add missing old_row_max = row_max before softmax max computation 2026-05-22 09:30:32 +00:00
aa9c2d2308 fix vectorize issue: remove vectorize from exp2 pass, add row_sum accumulation
- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
2026-05-22 09:29:43 +00:00
f631ff16d6 fix: use cute.arch.fmax instead of if-else in vectorized loop 2026-05-22 09:28:32 +00:00
941bcae8e1 softmax: element-wise row_max computation instead of .reduce()
The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
2026-05-22 09:27:36 +00:00
5e51b726ba fix O normalization: use direct rmem tensor from partition_D shape 2026-05-22 09:23:58 +00:00
0da960d8da FMHA Stage-C: real softmax + O normalization in 6-warp layout
- Replace identity softmax with online softmax (row_max, exp2 scaling, P store)
- Add row_sum accumulation from P values
- After softmax loop, normalize O in TMEM by 1/row_sum using TMEM load/modify/store
- Then epilogue writes normalized O from TMEM to GMEM
- Reference test uses softmax(Q@K^T/sqrt(d))@V
2026-05-22 09:22:56 +00:00
6ebccf1e7e fix: use make_smem_layout_epi not make_epilogue_smem_layout 2026-05-22 09:19:12 +00:00
208af3eadd FMHA v3 Stage-C full: 12-warp pipeline with real softmax + correction + epilogue
- Softmax warps (0-3): online row max, exp2 scaling, P store, vec broadcast
- Correction warps (4-7): online O rescale, final normalization, SMEM write
- MMA warp (8): QK->S, PV->O with proper pipeline chaining
- TMA warp (9): Q/K/V load
- Epilogue warp (10): TMA store O from SMEM to GMEM
- Empty warp (11): tmem dealloc mbar init
- Pipeline chain: mma_s -> softmax -> s_corr -> correction -> corr_epi -> epilogue
- Plus mma_corr -> correction for O rescale
- Reference test uses softmax(Q@K^T/sqrt(d))@V
2026-05-22 09:18:56 +00:00
b81ed1924b more stuff 2026-05-22 08:57:38 +00:00
7e1ba2b525 FMHA v3: per-row min test + explicit loop replacements
- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
  Still hangs — likely CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling

Key conclusion: per-row fix requires correction warp group.
The 6-warp code cant bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
2026-05-22 07:29:04 +00:00
791bdc53a0 FMHA v3: per-row patch from Mike + deadlock fix + V layout fix
- test_fmha_v3_per_row.py: Mike's per-row patch with deadlock fix
  (moved C6 O-rescale after softmax_done_bar, fixed pv_done_bar for kt=0)
  Still GPU hangs — needs further debugging
- test_fmha_v3_fixed_v.py: s_k parameter + acc_pipe consumer fix
  Same cosine as original (V TMA handles data shape correctly)
- Baseline: n=128→0.993, n=256→0.725, n=384→0.620

Key insight: QK TMEM load fragment has 4 rows × 32 cols per thread.
Fragment-level row_max/row_sum is wrong for per-row operations.
Per-row tracking (4 separate row_max/row_sum per thread) is needed.
2026-05-22 07:09:52 +00:00
4761931c3e FMHA v3: add debug variants for C9 normalization investigation
- test_fmha_v3_scalar: direct acc_scale for C6 O-rescale (no vector)
- test_fmha_v3_vec_c9: TMEM vector for C9 row_sum transfer
- test_fmha_v3_noop_c9: hardcoded inv_row_sum=1.0 (no normalization)
- test_fmha_v3_debug: row_sum-based C9 normalization
- test_fmha_v3_proper: 11-warp correction warp group (in progress)

Key findings:
- QK and PV C-fragments map threads to same logical rows
- pv_row_sum (PV-based P read) gives cosine 0.993 for n=128
- row_sum (QK-accumulated) gives cosine 0.514 for n=128
- Noop (inv_row_sum=1.0) gives cosine 0.866 for n=128
- pv_row_sum is NOT 1.0 - it corrects PV MMA accumulator errors
- The C9 normalization is essential even for single-tile case
2026-05-22 05:52:10 +00:00
23abfe9845 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - OOM raises on exhaustion

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00
39c1592d9c Clean up: remove debug/temp files and dangling test kernels 2026-05-21 23:26:50 +00:00
b034c915d1 10-warp debug: MMA=warp4 TMA=warp5 idle=6-9 still gives cosine 0.29
Pipeline init uses __syncthreads (all 320 threads participate).
Pipeline groups match 6-warp exactly.
Only difference: threads_per_cta=320 vs 192.

Direct comparison: 6-warp output [15,-129,-77.5,65,59]
vs 10-warp output [-7.5,2.2,-22.7,7.3,12.0] for row 0.
Completely different values.

Something in CuTe DSL runtime uses blockDim.x or total CTA size
in a way that breaks computation when CTA size changes from 192 to 320.

The pipeline_init_wait calls agent_sync(ThreadBlock) = __syncthreads
which all 320 threads reach. NamedBarriers use specific thread counts.
TMA atoms are created from MMA thread layout, not CTA size.

Hypothesis: the PipelineTmaUmma or PipelineUmmaAsync internally
uses blockDim.x for barrier arithmetic, making the barriers expect
more participants than the actual working threads.
2026-05-21 23:24:44 +00:00
0b8f4da323 Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imports (NotImplementedError until kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
2026-05-21 23:11:09 +00:00
c681b591a0 10-warp idle test: no crash but cosine 0.29 (6-warp gives 0.999999)
Adding 4 idle warps (4-7) to 320-thread CTA:
- No crash, no deadlock (idle warps just pass)
- But output is garbage: cosine 0.29 vs 0.999999

Same softmax+MMA code, same TMEM layout, same barriers.
Only difference: mma_warp_id=8 (was 4), threads_per_cta=320 (was 192)
and 4 idle warps 4-7.

Something in the pipeline/barrier system assumes the old 6-warp topology.
Need to identify which component uses threads_per_cta or warp_idx
in a way that breaks with more warps.
2026-05-21 22:07:53 +00:00
fb243a4133 Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00
a4d12fd560 WIP: correction warp group architecture - compiles, illegal address at runtime
4 softmax warps (0-3), 4 correction warps (4-7), 1 MMA (8), 1 TMA (9).
320 threads total.

Softmax: QK→softmax, write P, write row metadata to TMEM vector.
Correction: read vector via QK partition, rescale O (C6), normalize O (C9).

Compiles successfully but hits CUDA_ERROR_ILLEGAL_ADDRESS at runtime.
Likely: vector TMEM offsets or correction TMEM access layout is wrong.

Key files:
- tests/unit/test_fmha_v3_correction.py (new correction architecture)
- tests/unit/test_fmha_v3_softmax.py (working n=128, cosine 0.993)
2026-05-21 21:20:39 +00:00
bb3ad3d2ef BREAKTHROUGH: cosine 0.993 for n=128! PV-partitioned P row sum works.
C9 fix: instead of using QK-partitioned row_sum (which maps to wrong PV rows),
read P from TMEM using PV partition and sum via .reduce(ADD).

QK: thread N owns row N//4, PV: thread N owns row N.
Reading P via PV partition gives each thread its correct row P values.

n=128: cosine 0.993 (was 0.514)
n=256: cosine 0.725 (C6 still broken for multi-tile)
n=384: cosine 0.676 (same C6 issue)

Remaining: C6 O-rescale for multi-tile needs same PV-partitioned fix.
Small accuracy gap (0.993 vs 0.999) likely from BF16 P store/load round-trip.
2026-05-21 20:13:51 +00:00
7d1c402a6d WIP: TMEM vector bridge not working (same cosine 0.513)
row_sum is PROVEN correct (29.25 vs 29.22 for row 0, ratio 1.001).
The ONLY bug is QK→PV row mapping in C9 normalization.

Tried: composition(tStS,(128,1)) for write, composition(tOtO,(128,1)) for read.
Same result — the composition preserves the fragments internal thread-to-address
mapping, so the same thread writes and reads the same TMEM address regardless
of which fragment layout is used for the composition.

Need: absolute row-coordinate indexed TMEM vector. Each QK thread writes
inv_row_sum to vec[QK_row_id], each PV thread reads from vec[PV_row_id].
The row_id comes from the identity tensor coordinate.

Alternative: implement FMHA correction_epilog pattern with dedicated
correction warp group that reads row metadata from the vector.
2026-05-21 19:26:15 +00:00
cae87fd744 WIP: confirmed row_sum is wrong (5.5 vs correct 29.22 for row 0)
The packed f32x2 reduction SHOULD sum all 128 exp2 P values but gives
a result ~5.3x too small. Need to debug inside the kernel with print
statements to see what values the reduction is actually summing.

Unnormalized P@V is perfect (cosine 0.999998). row_max is correct
(because P is correct). The bug is specifically in row_sum computation.
2026-05-21 19:16:15 +00:00