Commit Graph

221 Commits

Author SHA1 Message Date
594878101d FIX: consistent GMEM/SMEM slicing for K and V TMA partitions
Both GMEM and SMEM sides must be sliced to the same rank for cute.copy.
K (QK MMA B-partition): slice [(None,None,0,0)] keeps modes 0,1
  - mode 1 = GMEM iteration, indexed by kvh.count
V (PV MMA B-partition): slice [(None,0,None,0)] keeps modes 0,2
  - mode 2 = GMEM iteration, indexed by kvh.count
Q: only 1 tile, (None,0,None,0) hardcode is fine.
2026-05-22 16:56:38 +00:00
0065c0180a FIX: keep GMEM iteration dimension FREE in TMA K/V partition slices
Root cause of multi-tile failure: (None,0,None,0) slice hardcodes the
GMEM tile dimension to 0, so TMA always loads from tile 0 regardless
of kvh.count. K from QK MMA has GMEM iter at mode 1, V from PV MMA
has it at mode 2 (different layouts: K,D,L vs D,K,L).

Fix follows CUTLASS FMHA reference:
- K: tBgK[(None,None,None,0)] + tBgK[(None, kvh.count, None)]
- V: tVgV[(None,0,None,0)] + tVgV[(None, kvh.count)]
2026-05-22 16:51:57 +00:00
b7c3b6613c Add diagnostic test for multi-tile TMA pipeline (identity softmax) 2026-05-22 16:47:08 +00:00
3712fda6ce FIX: acc_scale was double-multiplying by scale_log2
row_max is already in log2(scaled) space (S * scale_log2), so
old_row_max - row_max_safe is the correct exponent for exp2.
The old code computed exp2(scale_log2 * (old_row_max - row_max_safe))
which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong.
2026-05-22 16:42:45 +00:00
8ac35a32d7 Stage C: integrate example3 multi-tile fixes into unit test
- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
2026-05-22 16:39:45 +00:00
3ddee26eff FMHA Stage-C multi-tile: combined K+V barrier, final_o_bar, acc_pipe producer
Key changes from Mike:
1. Combined K+V TMA barrier: one acquire per kt, both cute.copys share
   kvh.barrier. kvh.count naturally == kt (no interleaving problem).
   tx_count = K_bytes + V_bytes. Also fixes the sK[0]/sV[1] slot quirk.
2. final_o_bar NamedBarrier: MMA .arrive() after acc_pipe.producer_tail;
   softmax .arrive_and_wait() before reading O for normalize. Prevents
   softmax racing MMA's PV[N-1] on the final O read.
3. acc_pipe producer in MMA: producer_acquire before loop, commit+advance
   after loop, producer_tail after. Consumer in epilogue as before.
4. O rescale re-enabled for kt>0 with acc_scale before softmax_done_bar.
2026-05-22 16:23:36 +00:00
d83434384e restore tBgK to kh.count indexing (single-tile working), add TODO for multi-tile
CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing.
kh.count works for single tile. For multi-tile, need to either:
1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count

This is the architectural blocker for multi-tile FMHA.
2026-05-22 15:54:03 +00:00
493d8e817d FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing
The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
2026-05-22 15:52:56 +00:00
d3031eefd8 CRITICAL FIX: keep GMEM iteration dim free in tBgK/tVgV slice
The slice (None,0,None,0) was hardcoding the GMEM iteration dim to 0,
meaning TMA always loaded K/V from tile 0 regardless of kt.
Changed to (None,None,None,0) to keep gmem_iter free,
then index with (None, kt, None) in the TMA copy loop.

This is the root cause of multi-tile failure: TMA was always reading
the first 128 tokens for ALL KV tiles.
2026-05-22 15:52:06 +00:00
7b8ee862bd add explicit acc_pipe.consumer_wait before final normalize
Race condition: softmax reads O to normalize while MMA may still be
writing PV[N-1]. Single-tile wins by luck; multi-tile drifts.
Move acc_cons_st construction before the wait so epilogue reuses it.
2026-05-22 15:49:48 +00:00
1f489c3cc2 FMHA Stage-C multi-tile: Fix 1 (s_k=n), Fix 2 (TMA kt indexing), Fix 3 (O rescale)
Fix 1: s_k must equal actual n. With s_k < n, v_fmha layout only spans
first s_k V tokens and TMA reads OOB on later tiles.

Fix 2: TMA producer indexes K and V by kt (loop variable), NOT by the
pipeline's interleaved count. The kv pipeline interleaves K and V, so
pipeline count goes 0,1,2,3 but GMEM tiles should be K[0],V[0],K[1],V[1].

Fix 3: Online O rescale before softmax_done_bar. When row_max grows,
O must be multiplied by exp2(old_max - new_max) before MMA starts next PV.
2026-05-22 15:41:14 +00:00
f5931783d1 Revert "debug: test 12w identity softmax with n=256 to verify multi-tile pipeline"
This reverts commit 6cf8702e3c.
2026-05-22 10:25:48 +00:00
846b74da99 debug: test 12w identity softmax with n=256 to verify multi-tile pipeline 2026-05-22 10:24:53 +00:00
a18f43c8d4 debug: disable O rescaling to test multi-tile pipeline baseline 2026-05-22 10:23:37 +00:00
22c2cf35aa fix: revert to scaled row_max, use exp2(old_max - new_max) for O rescaling
row_max is in scaled domain (s_val * scale_log2). The O rescaling
should be exp2(old_max - new_max) without extra scale_log2 because
the max values already include the scaling factor.
2026-05-22 10:22:44 +00:00
3cd907b66b fix: compute row_max from RAW S values, not scaled
row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
2026-05-22 10:21:50 +00:00
1738ae8cfd fix: missing newline after self.s_k = s_k 2026-05-22 10:20:35 +00:00
ade32615f0 fix: add s_k param to FmhaV3StageC, use self.s_k for V FMHA reconstruction 2026-05-22 10:19:49 +00:00
648f216b7c Stage C: add online O rescaling for multi-tile KV + test n=256
- Move O TMEM load/store setup before softmax loop
- After P store: rescale O in TMEM by exp2((old_max - new_max) * scale)
- Only rescale for kt > 0 (first tile has no prior O to rescale)
- Use same TMEM load/modify/store pattern as final normalization
- Test both n=128 (1 tile) and n=256 (2 tiles)
2026-05-22 10:19:08 +00:00
36437ed15b fix: add epilogue warp to tmem_bar, restore wait_for_alloc in epilogue
The epilogue needs tmem_ptr for epilogue_tma_store.
It must be part of the tmem alloc barrier to synchronize.
2026-05-22 10:17:02 +00:00
02d5d9e5d0 fix: add softmax_done_bar to synchronize MMA PV with softmax P production
MMA must wait for softmax to produce P in TMEM before starting PV.
Without this, MMA reads stale P data from TMEM, causing deadlock.
softmax_done_bar: softmax warps arrive after P store, MMA waits before PV.
2026-05-22 10:15:26 +00:00
4b3f723bd9 fix: epilogue warp self-signals acc_pipe producer before consuming 2026-05-22 10:11:55 +00:00
59fc652012 fix: remove duplicate tmem free from epilogue (MMA warp handles dealloc) 2026-05-22 10:05:52 +00:00
8d70de5370 fix: add acc_pipe pipeline for epilogue, matching 12w pattern
- Add acc_bar to SS struct
- Create acc_pipe (full pipeline) before if blocks
- Pass acc_pipe to epilogue_tma_store (needs full pipeline, not participant)
2026-05-22 10:03:08 +00:00
c6076a1c57 fix: epilogue_warp_id must be tuple for epilogue_tma_store, check with [0] 2026-05-22 09:59:20 +00:00
ad8865eb73 fix: epilogue warp reuse mma_corr_cons pipeline instead of creating new one from st 2026-05-22 09:56:18 +00:00
b149c310ac fix: define cS and tScS in correction warps (not visible across if blocks) 2026-05-22 09:52:59 +00:00
d706a83902 fix: correct @cute.kernel indentation 2026-05-22 09:49:36 +00:00
e6ae915a2e fix: remove duplicate @cute.kernel decorator 2026-05-22 09:46:09 +00:00
e4cde73562 FMHA Stage-C2: production 12-warp pipeline with correction warps
- Softmax warps (0-3): S→softmax→P, vec=[old_max,new_max]→TMEM
- Correction warps (4-7): O rescale in TMEM, final normalize by row_sum
- MMA warp (8): QK→S, PV→O with pipeline chaining
- TMA warp (9): Q/K/V load
- Epilogue warp (10): O TMEM→GMEM via epilogue_tma_store
- Empty warp (11): tmem dealloc mbar init
- Pipeline: mma_s→softmax→s_corr→correction→corr_epi→epilogue + mma_corr→correction
- Supports multi-tile KV with online O rescale
- Follows CUTLASS FMHA correction_rescale pattern exactly
2026-05-22 09:42:39 +00:00
406bcbc20e test: add multiple seeds to verify softmax consistency 2026-05-22 09:32:08 +00:00
db2ee08e37 fix: use plain range loop for row_max (fmax not allowed in vectorized) 2026-05-22 09:31:07 +00:00
3aa597460e fix: add missing old_row_max = row_max before softmax max computation 2026-05-22 09:30:32 +00:00
ce06478e56 fix vectorize issue: remove vectorize from exp2 pass, add row_sum accumulation
- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
2026-05-22 09:29:43 +00:00
f1687ba3b8 fix: use cute.arch.fmax instead of if-else in vectorized loop 2026-05-22 09:28:32 +00:00
54eb16deef softmax: element-wise row_max computation instead of .reduce()
The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
2026-05-22 09:27:36 +00:00
1052c310b0 fix O normalization: use direct rmem tensor from partition_D shape 2026-05-22 09:23:58 +00:00
ec4b837522 FMHA Stage-C: real softmax + O normalization in 6-warp layout
- Replace identity softmax with online softmax (row_max, exp2 scaling, P store)
- Add row_sum accumulation from P values
- After softmax loop, normalize O in TMEM by 1/row_sum using TMEM load/modify/store
- Then epilogue writes normalized O from TMEM to GMEM
- Reference test uses softmax(Q@K^T/sqrt(d))@V
2026-05-22 09:22:56 +00:00
d791361fc2 fix: use make_smem_layout_epi not make_epilogue_smem_layout 2026-05-22 09:19:12 +00:00
cfc86c6b1b FMHA v3 Stage-C full: 12-warp pipeline with real softmax + correction + epilogue
- Softmax warps (0-3): online row max, exp2 scaling, P store, vec broadcast
- Correction warps (4-7): online O rescale, final normalization, SMEM write
- MMA warp (8): QK->S, PV->O with proper pipeline chaining
- TMA warp (9): Q/K/V load
- Epilogue warp (10): TMA store O from SMEM to GMEM
- Empty warp (11): tmem dealloc mbar init
- Pipeline chain: mma_s -> softmax -> s_corr -> correction -> corr_epi -> epilogue
- Plus mma_corr -> correction for O rescale
- Reference test uses softmax(Q@K^T/sqrt(d))@V
2026-05-22 09:18:56 +00:00
ff9e8aab6a more stuff 2026-05-22 08:57:38 +00:00
3fe5805ae9 FMHA v3: per-row min test + explicit loop replacements
- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
  Still hangs — likely CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling

Key conclusion: per-row fix requires correction warp group.
The 6-warp code cant bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
2026-05-22 07:29:04 +00:00
b9d7bf4db3 FMHA v3: per-row patch from Mike + deadlock fix + V layout fix
- test_fmha_v3_per_row.py: Mike's per-row patch with deadlock fix
  (moved C6 O-rescale after softmax_done_bar, fixed pv_done_bar for kt=0)
  Still GPU hangs — needs further debugging
- test_fmha_v3_fixed_v.py: s_k parameter + acc_pipe consumer fix
  Same cosine as original (V TMA handles data shape correctly)
- Baseline: n=128→0.993, n=256→0.725, n=384→0.620

Key insight: QK TMEM load fragment has 4 rows × 32 cols per thread.
Fragment-level row_max/row_sum is wrong for per-row operations.
Per-row tracking (4 separate row_max/row_sum per thread) is needed.
2026-05-22 07:09:52 +00:00
b1dda5cdc9 FMHA v3: add debug variants for C9 normalization investigation
- test_fmha_v3_scalar: direct acc_scale for C6 O-rescale (no vector)
- test_fmha_v3_vec_c9: TMEM vector for C9 row_sum transfer
- test_fmha_v3_noop_c9: hardcoded inv_row_sum=1.0 (no normalization)
- test_fmha_v3_debug: row_sum-based C9 normalization
- test_fmha_v3_proper: 11-warp correction warp group (in progress)

Key findings:
- QK and PV C-fragments map threads to same logical rows
- pv_row_sum (PV-based P read) gives cosine 0.993 for n=128
- row_sum (QK-accumulated) gives cosine 0.514 for n=128
- Noop (inv_row_sum=1.0) gives cosine 0.866 for n=128
- pv_row_sum is NOT 1.0 - it corrects PV MMA accumulator errors
- The C9 normalization is essential even for single-tile case
2026-05-22 05:52:10 +00:00
d315ceda82 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - OOM raises on exhaustion

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00
64b0ded719 Clean up: remove debug/temp files and dangling test kernels 2026-05-21 23:26:50 +00:00
e58517b221 10-warp debug: MMA=warp4 TMA=warp5 idle=6-9 still gives cosine 0.29
Pipeline init uses __syncthreads (all 320 threads participate).
Pipeline groups match 6-warp exactly.
Only difference: threads_per_cta=320 vs 192.

Direct comparison: 6-warp output [15,-129,-77.5,65,59]
vs 10-warp output [-7.5,2.2,-22.7,7.3,12.0] for row 0.
Completely different values.

Something in CuTe DSL runtime uses blockDim.x or total CTA size
in a way that breaks computation when CTA size changes from 192 to 320.

The pipeline_init_wait calls agent_sync(ThreadBlock) = __syncthreads
which all 320 threads reach. NamedBarriers use specific thread counts.
TMA atoms are created from MMA thread layout, not CTA size.

Hypothesis: the PipelineTmaUmma or PipelineUmmaAsync internally
uses blockDim.x for barrier arithmetic, making the barriers expect
more participants than the actual working threads.
2026-05-21 23:24:44 +00:00
5a63604d6a Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imports (NotImplementedError until kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
2026-05-21 23:11:09 +00:00
553f0c30c7 10-warp idle test: no crash but cosine 0.29 (6-warp gives 0.999999)
Adding 4 idle warps (4-7) to 320-thread CTA:
- No crash, no deadlock (idle warps just pass)
- But output is garbage: cosine 0.29 vs 0.999999

Same softmax+MMA code, same TMEM layout, same barriers.
Only difference: mma_warp_id=8 (was 4), threads_per_cta=320 (was 192)
and 4 idle warps 4-7.

Something in the pipeline/barrier system assumes the old 6-warp topology.
Need to identify which component uses threads_per_cta or warp_idx
in a way that breaks with more warps.
2026-05-21 22:07:53 +00:00
fbc5f6f9da Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00